<img src="images/00-image.png" alt="encoder" class="bg-primary" width="100%">


[Image Reference](https://www.planetware.com/tourist-attractions-/potsdam-d-br-pt.htm)

<h1><center> Attention Explained </center></h1>

Vision Transformer (ViT) paper: [Paper Reference](https://arxiv.org/abs/2010.11929)

In [1]:
import numpy as np

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

## Self Attention

In [3]:
# we need to find a way to parameterize each token so that we can rank them based on their importance

In [4]:
torch.manual_seed(42)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [5]:
x[0]

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047],
        [-0.7521,  1.6487],
        [-0.3925, -1.4036],
        [-0.7279, -0.5594],
        [-0.7688,  0.7624]])

## Version 1 : Using simple mathematical function

In [6]:
# Version 1: for every position t, average all tokens from 0 up to and including t.
attention1 = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        # running mean over the prefix; also called bow (bag of words),
        # or bag of pixels in our case
        attention1[b, t] = x[b, : t + 1].mean(dim=0)

# bow is an averaging scheme, which is why it can serve as a simple form of attention

In [7]:
x[0, :1+1]

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055]])

In [8]:
attention1[0, 1]

tensor([ 1.4138, -0.3091])

In [9]:
x[0]

tensor([[ 1.9269,  1.4873],
        [ 0.9007, -2.1055],
        [ 0.6784, -1.2345],
        [-0.0431, -1.6047],
        [-0.7521,  1.6487],
        [-0.3925, -1.4036],
        [-0.7279, -0.5594],
        [-0.7688,  0.7624]])

In [10]:
attention1[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

### Replicating above by combining Matmul and Tril

In [11]:
# Explaining concept of matmul - example
a = torch.ones(3,3)
print(a)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


In [12]:
b = torch.randint(0, 10, (3, 2)).float()
print(b)

tensor([[0., 1.],
        [3., 0.],
        [1., 1.]])


In [13]:
c = a@b
print(c)

tensor([[4., 2.],
        [4., 2.],
        [4., 2.]])


In [14]:
# Explaining concept of Tril - example
torch.tril(torch.ones(8, 8))

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [15]:
# Combining tril and matmul
torch.manual_seed(42)
a = torch.tril(torch.ones(8, 8))
b = torch.randint(0, 10, (8, 2)).float()
c = a@b

print(a)
print(b)
print(c)

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.],
        [0., 4.],
        [0., 3.],
        [8., 4.],
        [0., 4.],
        [1., 2.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.],
        [14., 20.],
        [14., 23.],
        [22., 27.],
        [22., 31.],
        [23., 33.]])


In [16]:
# averaging with tril: dividing each row by its sum makes the matmul return the running average of the rows of b
torch.manual_seed(42)
a = torch.tril(torch.ones(8,8))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (8, 2)).float()
c = a@b

print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.],
        [0., 4.],
        [0., 3.],
        [8., 4.],
        [0., 4.],
        [1., 2.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333],
        [3.5000, 5.0000],
        [2.8000, 4.6000],
        [3.6667, 4.5000],
        [3.1429, 4.4286],
        [2.8750, 4.1250]])


### Introducing and Establishing the SCORES / WEIGHT

In [17]:
# Now applying it
tril = torch.tril(torch.ones(T,T))

In [18]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [19]:
scores = tril / torch.sum(tril, 1, keepdim=True)
scores

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [20]:
attention2 = scores @ x

In [21]:
torch.allclose(attention2[0], attention1[0])

True

In [22]:
attention1[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

In [23]:
attention2[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

- Now we have found a way to parameterize them such that we have our weight (average) separated

#

## Version 2 : Using softmax - another method but better

- The weight is designed differently using softmax
- Here softmax does the same thing as above: it distributes the weight evenly across the visible tokens
- However, that is only the case because all unmasked scores are equal (0); once the scores are data dependent (QKV), softmax turns them into unequal, normalized weights, which is where it becomes the better choice

In [24]:
tril = torch.tril(torch.ones(T,T)) # because 8 tokens must result into 8X8
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [25]:
scores = torch.zeros(T,T)

In [26]:
scores = scores.masked_fill(tril == 0, float('-inf'))

In [27]:
scores

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [28]:
p_attn = torch.softmax(scores, dim=-1)
p_attn

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [29]:
attention3 = p_attn @ x

In [30]:
attention3[0]

tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])

In [31]:
torch.allclose(attention2[0], attention3[0])

True

#

## Version 3: Final with self attention (Putting everything together)

- Here, the weight will be data dependent, thus making softmax very useful 
- The idea is that, unlike the previous versions, we do not want the values to be a plain context average, otherwise it suggests that all tokens are equally important

Note below: In terms of image, we are no longer talking about 2D image, thus, arrangement of Channel, Width, Height no longer matter, everything is now 1D. Hence,
- B = batch, in image = channel
- T = time/Sentence, in image = flattened R/G/B 
- C = Depth dimension, can be any value you wish e.g RGB 

In [32]:
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

- Note: the QKV is for the weight initialization, hence must come out in block_size/context length size
- K and Q are both computed from the same input x (through separate linear layers); by transposing K, every token's Query can be multiplied with the Keys of all tokens
- Then we can estimate its affinity

In [33]:
# Build the three linear projections used by self-attention and project x
# into keys and queries. The value layer is only created here; it is applied
# to x in a later cell.
# NOTE: the layers are created in key, query, value order; changing the order
# changes their random initial weights and therefore the printed outputs.
head_size = 3 # 16 # output size of each key/query/value projection (3 here; 16 is an alternative setting)
key = nn.Linear(in_features=C, out_features=head_size)
query = nn.Linear(in_features=C, out_features=head_size)
value = nn.Linear(in_features=C, out_features=head_size)
k = key(x) # shape = (B, T, head_size) = (B, T, 3)
q = query(x) # shape = (B, T, head_size) = (B, T, 3)

In [34]:
k.shape

torch.Size([4, 8, 3])

In [35]:
k.transpose(-2, -1).shape

torch.Size([4, 3, 8])

In [36]:
matmul_qk = q@k.transpose(-2, -1)

In [37]:
matmul_qk.shape

torch.Size([4, 8, 8])

In [38]:
scores = matmul_qk * k.size(-1)**-0.5 # Scale Factor

In [39]:
scores

tensor([[[ 0.1586, -0.3868, -0.5070, -0.2084, -0.0627, -0.2190, -0.2884,
          -0.3706],
         [-0.0637, -0.2880, -0.2969, -0.0997, -0.0577, -0.1326, -0.2706,
          -0.2463],
         [-0.1438, -0.2684, -0.2761, -0.1715, -0.1470, -0.1879, -0.2572,
          -0.2476],
         [-0.0794, -0.3266, -0.4379, -0.4069, -0.3157, -0.3716, -0.2495,
          -0.3684],
         [-0.0062, -0.3520, -0.4828, -0.3935, -0.2770, -0.3618, -0.2584,
          -0.3889],
         [-0.0617, -0.3231, -0.4159, -0.3372, -0.2518, -0.3175, -0.2558,
          -0.3457],
         [-0.0057, -0.3045, -0.3203, -0.0646, -0.0069, -0.1058, -0.2793,
          -0.2523],
         [-0.0840, -0.2928, -0.3252, -0.1857, -0.1360, -0.1994, -0.2629,
          -0.2748]],

        [[-0.6574,  0.0109, -0.2844, -0.3501, -0.3486, -0.5219, -0.0957,
          -0.5622],
         [-0.2279, -0.3581, -0.1941, -0.1711, -0.3596, -0.2271, -0.4246,
          -0.2480],
         [-0.5419, -0.2499, -0.1904, -0.2011, -0.5337, -0.4346, -0.4

- The above has made our initial weight to be data driven
- Now we need to truncate using tril below

In [40]:
tril = torch.tril(torch.ones(T,T))
scores = scores.masked_fill(tril == 0, float('-inf')) # Comment out to obtain bi-directional effect
scores[0]

tensor([[ 0.1586,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0637, -0.2880,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1438, -0.2684, -0.2761,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0794, -0.3266, -0.4379, -0.4069,    -inf,    -inf,    -inf,    -inf],
        [-0.0062, -0.3520, -0.4828, -0.3935, -0.2770,    -inf,    -inf,    -inf],
        [-0.0617, -0.3231, -0.4159, -0.3372, -0.2518, -0.3175,    -inf,    -inf],
        [-0.0057, -0.3045, -0.3203, -0.0646, -0.0069, -0.1058, -0.2793,    -inf],
        [-0.0840, -0.2928, -0.3252, -0.1857, -0.1360, -0.1994, -0.2629, -0.2748]],
       grad_fn=<SelectBackward0>)

In [41]:
p_attn = torch.softmax(scores, dim=-1)
p_attn[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5558, 0.4442, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3625, 0.3200, 0.3176, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3125, 0.2440, 0.2183, 0.2252, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2652, 0.1877, 0.1647, 0.1801, 0.2023, 0.0000, 0.0000, 0.0000],
        [0.2070, 0.1593, 0.1452, 0.1571, 0.1711, 0.1602, 0.0000, 0.0000],
        [0.1645, 0.1220, 0.1201, 0.1551, 0.1643, 0.1488, 0.1251, 0.0000],
        [0.1428, 0.1159, 0.1122, 0.1290, 0.1356, 0.1272, 0.1194, 0.1180]],
       grad_fn=<SelectBackward0>)

In [42]:
p_attn.shape

torch.Size([4, 8, 8])

In [43]:
v = value(x) # we aggregate the values not the exact token, it is also learnable
v.shape

torch.Size([4, 8, 3])

In [44]:
attention = p_attn @ v

In [45]:
attention.shape

torch.Size([4, 8, 3])

## Putting it all together

In [46]:
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product attention.

    Args:
        embed_dim: Size of the last dimension of the query/key/value inputs.
        head_size: Output size of each of the key, query and value projections.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, head_size, dropout):
        super().__init__()
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # The causal mask is built per call in forward(), sized from the
        # actual inputs, so no fixed block-size buffer is registered here.
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Compute masked attention of ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (B, T_q, embed_dim).
            key: Tensor of shape (B, T_k, embed_dim).
            value: Tensor of shape (B, T_k, embed_dim).

        Returns:
            Tensor of shape (B, T_q, head_size). Output position ``i`` only
            aggregates values from positions ``0..i`` (lower-triangular mask,
            aligned to the top-left corner when ``T_q != T_k``).
        """
        key = self.key(key)        # (B, T_k, hs)
        query = self.query(query)  # (B, T_q, hs)

        # Attention scores ("affinities") between every query and every key.
        matmul_qk = query @ key.transpose(-2, -1)  # (B, T_q, hs) @ (B, hs, T_k) -> (B, T_q, T_k)

        # Scale by 1/sqrt(head_size) so the softmax does not saturate as hs grows.
        scores = matmul_qk * key.size(-1) ** -0.5

        # Build the causal mask from the scores themselves: same device as the
        # inputs (so the module also works on GPU) and the right shape even
        # when query and key have different lengths.
        # Comment out the two lines below to obtain the bi-directional effect.
        t_q, t_k = scores.shape[-2:]
        future = torch.tril(torch.ones(t_q, t_k, device=scores.device)) == 0
        scores = scores.masked_fill(future, float('-inf'))  # (B, T_q, T_k)

        p_attn = F.softmax(scores, dim=-1)  # (B, T_q, T_k), rows sum to 1
        p_attn = self.dropout(p_attn)

        # Weighted aggregation of the (learnable) value projections.
        value = self.value(value)    # (B, T_k, hs)
        attention = p_attn @ value   # (B, T_q, T_k) @ (B, T_k, hs) -> (B, T_q, hs)

        return attention