<img src="images/00-image.png" alt="encoder" class="bg-primary" width="100%">


[Image Reference](https://www.planetware.com/tourist-attractions-/potsdam-d-br-pt.htm)

<h1><center> Attention Explained </center></h1>

Vision Transformer (ViT) paper: [Paper Reference](https://arxiv.org/abs/2010.11929)

In [1]:
import numpy as np

In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

## Self Attention

In [4]:
# we need to find a way to parameterize each token so that we can rank them based on their importance

In [5]:
# Seed the global RNG so the toy batch below is the same on every fresh run
torch.manual_seed(42)
B = 4  # batch size
T = 8  # time dimension (number of tokens)
C = 2  # channels per token
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [6]:
x[0]

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047],
        [-0.7521,  1.6487],
        [-0.3925, -1.4036],
        [-0.7279, -0.5594],
        [-0.7688,  0.7624]])

## Version 1 : Using simple mathematical function

In [7]:
# Running average: position `step` receives the mean of tokens 0..step of the
# same sequence (a "bag of words" / "bag of pixels" in our case).
attention1 = torch.zeros(B, T, C)
for batch_idx in range(B):
    for step in range(T):
        history = x[batch_idx, : step + 1]  # (step + 1, C): the current token and everything before it
        attention1[batch_idx, step] = history.mean(dim=0)

# A bag-of-words is an averaging scheme, which makes it the simplest possible form of attention

In [8]:
x[0, :3+1]

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047]])

In [9]:
attention1[0, :3]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176]])

### Replicating the above by combining matmul and tril

In [10]:
# matmul warm-up: start from a 3x3 matrix filled with ones
a = torch.ones((3, 3))
print(a)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


In [11]:
# Random integers in [0, 10), cast to float32
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
print(b)

tensor([[0., 1.],
        [3., 0.],
        [1., 1.]])


In [12]:
# Matrix product of `a` and `b`
c = torch.matmul(a, b)
print(c)

tensor([[4., 2.],
        [4., 2.],
        [4., 2.]])


In [13]:
# Explaining concept of Tril - example
torch.tril(torch.ones(8, 8))

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [14]:
# tril + matmul: a lower-triangular matrix of ones turns the product into a
# running (cumulative) column sum of `b`.
torch.manual_seed(42)
b = torch.randint(low=0, high=10, size=(8, 2)).float()
a = torch.ones(8, 8).tril()
c = torch.matmul(a, b)

print(a, b, c, sep="\n")

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.],
        [0., 4.],
        [0., 3.],
        [8., 4.],
        [0., 4.],
        [1., 2.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.],
        [14., 20.],
        [14., 23.],
        [22., 27.],
        [22., 31.],
        [23., 33.]])


In [16]:
# Row-normalised tril: each row sums to 1, so multiplying by it replaces every
# row of `b` with the average of that row and all rows before it.
torch.manual_seed(42)
b = torch.randint(low=0, high=10, size=(8, 2)).float()
a = torch.ones(8, 8).tril()
a = a / a.sum(dim=1, keepdim=True)
c = torch.matmul(a, b)

print(a, b, c, sep="\n")

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.],
        [0., 4.],
        [0., 3.],
        [8., 4.],
        [0., 4.],
        [1., 2.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333],
        [3.5000, 5.0000],
        [2.8000, 4.6000],
        [3.6667, 4.5000],
        [3.1429, 4.4286],
        [2.8750, 4.1250]])


### Introducing and Establishing the SCORES / WEIGHT

In [15]:
# Now applying it: a (T, T) lower-triangular matrix of ones
tril = torch.ones(T, T).tril()

In [16]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [18]:
# Divide every row by its number of ones so that each row sums to 1
row_totals = tril.sum(dim=1, keepdim=True)
scores = tril / row_totals
scores

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [19]:
# Weighted aggregation of the tokens in `x` using the weight matrix `scores`
attention2 = torch.matmul(scores, x)

In [20]:
torch.allclose(attention2[0], attention1[0])

True

In [21]:
attention1[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

In [22]:
attention2[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

- We have now found a way to parameterize the tokens: the averaging weights live in their own matrix (`scores`), separate from the data they are applied to

#

## Version 2 : Using softmax - another method but better

- The weights are built differently here, using softmax
- In this particular case softmax does the same thing as above: it distributes the weight evenly across the allowed positions
- That is only because the scores are all 0 (with -inf for masked positions); once the scores are data dependent (the Q/K/V formulation), softmax becomes the better choice

In [23]:
# T tokens need a (T, T) matrix: one weight for every (row token, column token) pair
tril = torch.ones(T, T).tril()
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [38]:
# Start from all-zero scores of shape (T, T)
scores = torch.zeros((T, T))

In [39]:
# Wherever `tril` is zero, set the score to -inf
scores = scores.masked_fill(~tril.bool(), float('-inf'))

In [40]:
scores

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [41]:
# Softmax over the last dimension turns each row of scores into weights that sum to 1
p_attn = F.softmax(scores, dim=-1)
p_attn

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [42]:
# Aggregate the tokens in `x` with the softmax weights
attention3 = torch.matmul(p_attn, x)

In [43]:
attention3[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

In [44]:
torch.allclose(attention2[0], attention3[0])

True

#

## Version 3: Final with self attention (Putting everything together)

- Here the weights are data dependent, which is where softmax becomes genuinely useful
- Unlike the previous versions, we do not want the values to be a plain average over the context, because that would suggest that all tokens are equally important

Note below: for images we are no longer talking about a 2D layout, so the arrangement of channel, width and height no longer matters; everything is now 1D. Hence,
- B = batch, in image = channel
- T = time/sentence, in image = flattened R/G/B
- C = depth dimension, can be any value you wish, e.g. RGB

In [45]:
# Fresh toy batch for the self-attention walkthrough.
# The RNG is not reseeded here, so the values continue from the current generator state.
B = 4  # batch size
T = 8  # number of tokens
C = 2  # channels per token
x = torch.randn((B, T, C))

- Note: Q, K and V are used to build the weights, so the resulting weight matrix must come out at block_size / context length size, i.e. (T, T)
- K and Q are both computed from the same input `x` (through separate linear layers); by transposing K, every token's query can be multiplied with the keys of all tokens
- From that product we can estimate the affinity between tokens

In [46]:
head_size = 3  # small toy head size (the original comment also mentions 16)

# Three separate learnable projections from C channels to head_size features.
# The creation order key -> query -> value is kept so the random initialisation is unchanged.
key = nn.Linear(C, head_size)
query = nn.Linear(C, head_size)
value = nn.Linear(C, head_size)

q = query(x)  # shape = (B, T, 3)
k = key(x)  # shape = (B, T, 3)

In [47]:
k.shape

torch.Size([4, 8, 3])

In [53]:
k.transpose(-2, -1).shape

torch.Size([4, 3, 8])

In [54]:
# Query of every token against the key of every token: swap the last two dims of k, then matmul
matmul_qk = torch.matmul(q, k.transpose(-1, -2))

In [55]:
matmul_qk.shape

torch.Size([4, 8, 8])

In [59]:
matmul_qk[0]

tensor([[-0.3184, -0.7780, -0.0200,  0.0947, -0.1123, -0.5016, -0.9347,  0.2224],
        [-0.5950, -1.0813,  0.2122, -0.0400, -0.5541, -0.0613, -1.3578,  0.3490],
        [-0.2704, -0.9691,  0.0344,  0.3219,  0.0966, -0.7693, -1.1739,  0.4391],
        [-0.1013, -0.5984, -0.1795,  0.2493,  0.2661, -0.8931, -0.6777,  0.1802],
        [-0.1469, -0.5021, -0.1981,  0.1047,  0.1139, -0.7057, -0.5598,  0.0578],
        [-0.6235, -1.4734,  0.3764,  0.2479, -0.4039, -0.2991, -1.8641,  0.7155],
        [-0.6597, -1.0974,  0.2452, -0.1175, -0.6872,  0.0848, -1.3864,  0.3249],
        [-0.0925, -0.7147, -0.1380,  0.3589,  0.3486, -1.0060, -0.8257,  0.2996]],
       grad_fn=<SelectBackward0>)

In [60]:
# Scale factor: 1 / sqrt(key size), i.e. the last dimension of k raised to the power -0.5
scale = k.size(-1) ** -0.5
scores = matmul_qk * scale

In [63]:
scores[0]

tensor([[-0.1838,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3435, -0.6243,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1561, -0.5595,  0.0199,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0585, -0.3455, -0.1036,  0.1439,    -inf,    -inf,    -inf,    -inf],
        [-0.0848, -0.2899, -0.1144,  0.0605,  0.0658,    -inf,    -inf,    -inf],
        [-0.3600, -0.8507,  0.2173,  0.1431, -0.2332, -0.1727,    -inf,    -inf],
        [-0.3809, -0.6336,  0.1416, -0.0678, -0.3968,  0.0489, -0.8004,    -inf],
        [-0.0534, -0.4126, -0.0797,  0.2072,  0.2012, -0.5808, -0.4767,  0.1729]],
       grad_fn=<SelectBackward0>)

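- The above makes our initial weights data driven
- Now we need to mask out the upper triangle using tril, as done below
- Note: the output shown above already contains `-inf` entries even though the cell that produced `scores` only applies the scale factor; it was evidently captured after the masking cell below had been run. On a fresh Restart & Run All it would show the unmasked, scaled scores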

In [64]:
# Wherever the lower-triangular mask is zero, set the score to -inf
tril = torch.ones(T, T).tril()
scores = scores.masked_fill(~tril.bool(), float('-inf'))  # Comment out to obtain bi-directional effect
scores[0]

tensor([[-0.1838,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3435, -0.6243,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1561, -0.5595,  0.0199,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0585, -0.3455, -0.1036,  0.1439,    -inf,    -inf,    -inf,    -inf],
        [-0.0848, -0.2899, -0.1144,  0.0605,  0.0658,    -inf,    -inf,    -inf],
        [-0.3600, -0.8507,  0.2173,  0.1431, -0.2332, -0.1727,    -inf,    -inf],
        [-0.3809, -0.6336,  0.1416, -0.0678, -0.3968,  0.0489, -0.8004,    -inf],
        [-0.0534, -0.4126, -0.0797,  0.2072,  0.2012, -0.5808, -0.4767,  0.1729]],
       grad_fn=<SelectBackward0>)

In [65]:
# Normalise each row of scores into attention weights that sum to 1
p_attn = F.softmax(scores, dim=-1)
p_attn[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5697, 0.4303, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3496, 0.2335, 0.4169, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2544, 0.1909, 0.2432, 0.3115, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1959, 0.1596, 0.1902, 0.2265, 0.2278, 0.0000, 0.0000, 0.0000],
        [0.1353, 0.0829, 0.2411, 0.2238, 0.1536, 0.1632, 0.0000, 0.0000],
        [0.1249, 0.0970, 0.2105, 0.1708, 0.1229, 0.1919, 0.0821, 0.0000],
        [0.1289, 0.0900, 0.1255, 0.1673, 0.1663, 0.0761, 0.0844, 0.1616]],
       grad_fn=<SelectBackward0>)

In [66]:
p_attn.shape

torch.Size([4, 8, 8])

In [67]:
# We aggregate learnable projections of the tokens (the "values"), not the raw tokens themselves
v = value(x)
v.size()

torch.Size([4, 8, 3])

In [68]:
# Weighted sum of the values using the attention weights
attention = torch.matmul(p_attn, v)

In [69]:
attention.shape

torch.Size([4, 8, 3])

## Putting it all together

In [46]:
class Head(nn.Module):
    """One head of scaled dot-product attention.

    Args:
        embed_dim: size of the last dimension of the inputs to ``forward``.
        head_size: output size of the key, query and value projections.
        dropout: dropout probability applied to the attention weights.
        causal: if True (default, the original behaviour) query position ``i``
            may only attend to key positions ``0..i``. Set to False for
            bi-directional attention.
    """

    def __init__(self, embed_dim, head_size, dropout, causal=True):
        super().__init__()
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        self.causal = causal

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Compute attention of ``query`` over ``key`` / ``value``.

        Args:
            query: tensor of shape (B, Tq, embed_dim).
            key: tensor of shape (B, Tk, embed_dim).
            value: tensor of shape (B, Tk, embed_dim).

        Returns:
            Tensor of shape (B, Tq, head_size).
        """
        key = self.key(key)        # (B, Tk, hs)
        query = self.query(query)  # (B, Tq, hs)

        # compute attention scores ("affinities")
        matmul_qk = query @ key.transpose(-2, -1)  # (B, Tq, hs) @ (B, hs, Tk) -> (B, Tq, Tk)

        scores = matmul_qk * key.size(-1) ** -0.5  # scale by 1 / sqrt(head size)

        if self.causal:
            # Build the mask from the actual score shape and on the same device as
            # the scores, so it works for CUDA inputs and when Tq != Tk.
            q_len, k_len = scores.shape[-2:]
            allowed = torch.tril(torch.ones(q_len, k_len, device=scores.device))
            scores = scores.masked_fill(allowed == 0, float('-inf'))  # (B, Tq, Tk)

        p_attn = F.softmax(scores, dim=-1)  # (B, Tq, Tk); masked positions get weight 0

        p_attn = self.dropout(p_attn)

        # perform the weighted aggregation of the values
        value = self.value(value)   # (B, Tk, hs)
        attention = p_attn @ value  # (B, Tq, Tk) @ (B, Tk, hs) -> (B, Tq, hs)

        return attention