In [1]:
import torch
import numpy
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms

## Self Attention

In [2]:
# we need to find a way to parameterize each token so that we can rank them based on their importance

In [3]:
# fix the seed so the random input below is reproducible
torch.manual_seed(1337)
# presumably B = batch, T = time / sequence length, C = channels (features per token)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [4]:
#x

## v1 : No weight matrix, because everything is done together in one loop

In [5]:
# running ("bag of words") average: position `step` holds the mean of tokens 0..step
xbow = torch.zeros(B, T, C)
for batch_idx in range(B):
    for step in range(T):
        history = x[batch_idx, : step + 1]          # every token up to and including `step`
        xbow[batch_idx, step] = history.mean(dim=0)  # average along the time axis

In [6]:
x[0, :1+1]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152]])

In [7]:
xbow[0, 1]

tensor([-0.0894, -0.4926])

In [8]:
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

#


## v2 : Replicating the above by combining matmul and tril

In [9]:
# Explaining concept of matmul - example
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0, 10, (3, 2)).float()
c = a@b

print(a)
print(b)
print(c)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [10]:
# Explaining concept of Tril - example
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [11]:
# Combining tril and matmul
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3, 2)).float()
c = a@b

print(a)
print(b)
print(c)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [12]:
# Averaging with tril: dividing every row of the lower-triangular mask by its row sum
# makes each row sum to 1, so multiplying by it averages the rows seen so far.
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim=True)  # normalise each row so it sums to 1
b = torch.randint(0, 10, (3, 2)).float()
c = a@b  # row i of c is the mean of rows 0..i of b

print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


--------------

In [13]:
# Now applying it
tril = torch.tril(torch.ones(T,T))

In [14]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [15]:
# normalise every row of the mask to sum to 1 -> uniform average over the visible positions
weight = tril / tril.sum(dim=1, keepdim=True)
weight

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [16]:
xbow2 = weight @ x

In [17]:
# Compare every batch, not only the first one. The loop (v1) and the matmul (v2)
# may round float32 sums slightly differently, so allow a small absolute tolerance.
torch.allclose(xbow2, xbow, atol=1e-6)

True

In [18]:
xbow2

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [19]:
# Now we have found a way to parameterize them such that we have our weight separated

#

## v3 : Using softmax - another method but better

In [20]:
# - the weight is built differently here, using softmax
# - In v3, softmax does the same thing as v2: it distributes the weight evenly over the visible positions
# - However, that is only because the scores before masking are all 0; once the scores are data dependent (QKV), softmax turns them into a non-uniform weighting, see V4

In [21]:
tril = torch.tril(torch.ones(T,T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [22]:
weight = torch.zeros(T,T)

In [23]:
weight = weight.masked_fill(tril == 0, float('-inf'))

In [24]:
weight

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [25]:
weight = torch.softmax(weight, dim=-1)
weight

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [26]:
xbow3 = weight @ x

In [27]:
xbow3

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [28]:
# v3 check: the softmax-built weights must reproduce the v1 loop result
torch.allclose(xbow3[0], xbow[0])

True

#

## V4 : Final with self attention

In [29]:
# - Here, the weight will be data dependent, thus making softmax very useful
# - The idea is, unlike the previous versions, we do not want the values to be uniform, otherwise it suggests that all tokens are equally important

In [30]:
# Note below: In terms of images, we are no longer talking about a 2D image, thus the arrangement of Channel, Width, Height no longer matters; everything is now 1D
# Hence B = batch, in image = channel
#       T = time/Sentence, in image = flattened R/G/B

In [31]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

In [32]:
# Self-attention starts here (single head).
# Note: Q and K are used to build the attention weights, so their product must come out
# with one row and one column per position in the context, i.e. (T, T) per batch element.
# key and query are two separate linear layers (same shape, independently initialised weights)
# applied to the same input x. Transposing k lets every token's query be dotted with every
# token's key; that dot product is used as the affinity between the two tokens.

head_size = 16
key = nn.Linear(C, head_size, bias=False) # same shape as query, but its own weights
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False) # applied to x later, when the values are aggregated
k = key(x) # shape = (B, T, head_size) = (4, 8, 16)
q = query(x) # shape = (B, T, head_size) = (4, 8, 16)

In [33]:
k.shape

torch.Size([4, 8, 16])

In [34]:
k.transpose(-2, -1).shape

torch.Size([4, 16, 8])

In [35]:
weight = q@k.transpose(-2, -1)

In [36]:
weight.shape

torch.Size([4, 8, 8])

In [37]:
weight

tensor([[[ 5.1839e-02, -1.2717e-01,  1.8555e-01,  2.8362e-01,  1.3338e-01,
          -4.1065e-01,  6.4280e-02, -2.3709e-01],
         [-2.7879e-02, -2.6872e-01, -1.4179e-02, -1.3808e-02,  3.4634e-01,
          -1.0934e-01, -8.2021e-02,  5.0689e-01],
         [ 1.6034e-01, -3.0773e-01,  5.5218e-01,  8.4200e-01,  3.0638e-01,
          -1.1863e+00,  2.1087e-01, -8.2965e-01],
         [ 2.4276e-01, -4.5681e-01,  8.3370e-01,  1.2711e+00,  4.5259e-01,
          -1.7872e+00,  3.2054e-01, -1.2664e+00],
         [ 1.0243e-02,  3.9295e-01, -6.9511e-02, -1.1600e-01, -4.9214e-01,
           3.2836e-01,  7.1550e-02, -5.1735e-01],
         [-3.1340e-01,  4.3864e-01, -1.0379e+00, -1.5788e+00, -3.9689e-01,
           2.1592e+00, -4.3509e-01,  1.8049e+00],
         [ 7.8257e-02, -2.3943e-01,  2.9217e-01,  4.4768e-01,  2.6021e-01,
          -6.6641e-01,  9.0359e-02, -3.0451e-01],
         [-3.4883e-01,  1.2351e+00, -1.3450e+00, -2.0646e+00, -1.3681e+00,
           3.1349e+00, -3.7915e-01,  1.1684e+00]],

In [38]:
# The above has made our initial weight data driven
# Now we need to mask out the future positions using tril below

In [39]:
# lower-triangular mask: row t keeps columns 0..t and hides everything after it
tril = torch.tril(torch.ones(T,T))
# unlike v3, weight is not reset to zeros here: it already holds the q @ k^T scores
weight = weight.masked_fill(tril == 0, float('-inf'))  # -inf becomes 0 after softmax
weight[0]

tensor([[ 0.0518,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0279, -0.2687,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1603, -0.3077,  0.5522,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2428, -0.4568,  0.8337,  1.2711,    -inf,    -inf,    -inf,    -inf],
        [ 0.0102,  0.3929, -0.0695, -0.1160, -0.4921,    -inf,    -inf,    -inf],
        [-0.3134,  0.4386, -1.0379, -1.5788, -0.3969,  2.1592,    -inf,    -inf],
        [ 0.0783, -0.2394,  0.2922,  0.4477,  0.2602, -0.6664,  0.0904,    -inf],
        [-0.3488,  1.2351, -1.3450, -2.0646, -1.3681,  3.1349, -0.3791,  1.1684]],
       grad_fn=<SelectBackward0>)

In [40]:
weight = torch.softmax(weight, dim=-1)
weight[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5599, 0.4401, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3220, 0.2016, 0.4764, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1640, 0.0815, 0.2961, 0.4585, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2051, 0.3007, 0.1894, 0.1808, 0.1241, 0.0000, 0.0000, 0.0000],
        [0.0600, 0.1273, 0.0291, 0.0169, 0.0552, 0.7114, 0.0000, 0.0000],
        [0.1408, 0.1025, 0.1744, 0.2038, 0.1690, 0.0669, 0.1426, 0.0000],
        [0.0223, 0.1086, 0.0082, 0.0040, 0.0080, 0.7257, 0.0216, 0.1016]],
       grad_fn=<SelectBackward0>)

In [41]:
weight.shape

torch.Size([4, 8, 8])

In [42]:
v = value(x) # we aggregate the values not the exact token, it is also learnable
v.shape

torch.Size([4, 8, 16])

In [43]:
xbow3 = weight @ v

## Putting it all together

In [44]:
class Head(nn.Module):
    """One head of causally masked, scaled dot-product attention.

    Args:
        embed_dim: size of the last dimension of the tensors passed to ``forward``.
        head_size: size of the projected key / query / value vectors.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, head_size, dropout):
        super().__init__()
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # The causal mask is built in forward() from the actual sequence lengths,
        # so no fixed block_size buffer has to be registered here.

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` positions to ``key`` / ``value`` positions.

        Args:
            query: tensor of size (batch, T_q, embed_dim).
            key: tensor of size (batch, T_k, embed_dim).
            value: tensor of size (batch, T_k, embed_dim).

        Returns:
            Tensor of size (batch, T_q, head_size). Query position ``i`` only
            aggregates values from key positions ``j <= i``.
        """
        key = self.key(key)        # (B, T_k, hs)
        query = self.query(query)  # (B, T_q, hs)

        # compute attention scores ("affinities")
        wei = query @ key.transpose(-2, -1)  # (B, T_q, hs) @ (B, hs, T_k) -> (B, T_q, T_k)

        # scale by 1/sqrt(head_size) so the softmax does not saturate as head_size grows
        wei = wei * key.shape[-1] ** -0.5

        # Causal mask sized from the scores themselves and created on the same device,
        # so it works for GPU inputs and for query/key sequences of different length.
        n_query, n_key = wei.shape[-2], wei.shape[-1]
        hidden = torch.tril(torch.ones(n_query, n_key, device=wei.device)) == 0
        wei = wei.masked_fill(hidden, float('-inf'))  # (B, T_q, T_k)

        wei = F.softmax(wei, dim=-1)  # (B, T_q, T_k), each row sums to 1

        wei = self.dropout(wei)

        # perform the weighted aggregation of the values
        value = self.value(value)  # (B, T_k, hs)
        out = wei @ value  # (B, T_q, T_k) @ (B, T_k, hs) -> (B, T_q, hs)

        return out