<a href="https://colab.research.google.com/github/arjunravi26/transformer-implementation/blob/main/attn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [102]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

Data Preparation

In [8]:
sentence = 'The quick brown fox jumps over a lazy dog'
dc = {s: i for i, s in enumerate(sorted(sentence.split()))}
print(dc)

{'The': 0, 'a': 1, 'brown': 2, 'dog': 3, 'fox': 4, 'jumps': 5, 'lazy': 6, 'over': 7, 'quick': 8}


In [9]:
r = [dc[i] for i in sentence.split()]
sentence_int = torch.tensor(r)
print(sentence_int)

tensor([0, 8, 2, 4, 5, 7, 1, 6, 3])


Embedding

In [25]:
vocab_size = 50000
torch.manual_seed(42)
embed = nn.Embedding(vocab_size,3)
embed

Embedding(50000, 3)

In [22]:
embeddings = embed(sentence_int).detach()
embeddings.shape

torch.Size([9, 3])

In [27]:
embeddings

tensor([[ 1.9269,  1.4873,  0.9007],
        [ 1.2791,  1.2964,  0.6105],
        [-0.0431, -1.6047, -0.7521],
        [-0.7279, -0.5594, -0.7688],
        [ 0.7624,  1.6423, -0.1596],
        [ 1.0783,  0.8008,  1.6806],
        [-2.1055,  0.6784, -1.2345],
        [-0.4974,  0.4396, -0.7581],
        [ 1.6487, -0.3925, -1.4036]])

In [28]:
weights = embed.weight

In [31]:
sentence_int

tensor([0, 8, 2, 4, 5, 7, 1, 6, 3])

In [33]:
weights[1]

tensor([-2.1055,  0.6784, -1.2345], grad_fn=<SelectBackward0>)

In [36]:
torch.embedding(weights,torch.tensor([0,8]))

tensor([[1.9269, 1.4873, 0.9007],
        [1.2791, 1.2964, 0.6105]], grad_fn=<EmbeddingBackward0>)

Self attention mechanism

In [38]:
embeddings.shape

torch.Size([9, 3])

In [37]:
torch.manual_seed(42)
d = embeddings.shape[1]
d

3

Set dimension for query, key and value

In [41]:
d_q, d_k, d_v = 2, 2, 4

Initialize weights for query, key and value

In [64]:
W_query = nn.Parameter(data=torch.randn(size=(d,d_q)))
W_key = nn.Parameter(data=torch.randn(size=(d,d_k)))
W_value = nn.Parameter(data=torch.randn(size=(d,d_v)))

In [65]:
W_query, W_key, W_value

(Parameter containing:
 tensor([[ 0.1498, -0.2089],
         [-0.3870,  0.9912],
         [ 0.4679, -0.2049]], requires_grad=True),
 Parameter containing:
 tensor([[-0.7409,  0.3618],
         [ 1.9199, -0.2254],
         [-0.3417,  0.3040]], requires_grad=True),
 Parameter containing:
 tensor([[-0.6890, -1.1267, -0.2858, -1.0935],
         [ 1.1351,  0.7592, -3.5945,  0.0192],
         [ 0.1052,  0.9603, -0.5672, -0.5706]], requires_grad=True))

In [66]:
embeddings

tensor([[ 1.9269,  1.4873,  0.9007],
        [ 1.2791,  1.2964,  0.6105],
        [-0.0431, -1.6047, -0.7521],
        [-0.7279, -0.5594, -0.7688],
        [ 0.7624,  1.6423, -0.1596],
        [ 1.0783,  0.8008,  1.6806],
        [-2.1055,  0.6784, -1.2345],
        [-0.4974,  0.4396, -0.7581],
        [ 1.6487, -0.3925, -1.4036]])

In [69]:
query = embeddings @ W_query
key = embeddings @ W_key
value = embeddings @ W_value

In [72]:
query

tensor([[ 0.1345,  0.8871],
        [-0.0245,  0.8928],
        [ 0.2627, -1.4275],
        [-0.2523, -0.2449],
        [-0.5961,  1.5014],
        [ 0.6380,  0.2241],
        [-1.1556,  1.3653],
        [-0.5994,  0.6950],
        [-0.2579, -0.4458]], grad_fn=<MmBackward0>)

In [71]:
key

tensor([[ 1.1200,  0.6358],
        [ 1.3327,  0.3562],
        [-2.7919,  0.1174],
        [-0.2720, -0.3710],
        [ 2.6427, -0.1428],
        [ 0.1643,  0.7206],
        [ 3.2843, -1.2901],
        [ 1.4715, -0.5095],
        [-1.4955,  0.2583]], grad_fn=<MmBackward0>)

In [70]:
value

tensor([[ 0.4553, -0.1769, -6.4075, -2.5926],
        [ 0.6544,  0.1293, -5.3717, -1.7223],
        [-1.8709, -1.8921,  6.2068,  0.4456],
        [-0.2144, -0.3430,  2.6549,  1.2240],
        [ 1.3221,  0.2346, -6.0306, -0.7112],
        [ 0.3428,  1.0070, -4.1398, -2.1229],
        [ 2.0910,  1.7018, -1.1367,  3.0199],
        [ 0.7620,  0.1661, -1.0080,  0.9850],
        [-1.7291, -3.5036,  1.7357, -1.0094]], grad_fn=<MmBackward0>)

Calculating Attention Scores

In [75]:
query.shape

torch.Size([9, 2])

In [76]:
key.shape

torch.Size([9, 2])

In [91]:
query

tensor([[ 0.1345,  0.8871],
        [-0.0245,  0.8928],
        [ 0.2627, -1.4275],
        [-0.2523, -0.2449],
        [-0.5961,  1.5014],
        [ 0.6380,  0.2241],
        [-1.1556,  1.3653],
        [-0.5994,  0.6950],
        [-0.2579, -0.4458]], grad_fn=<MmBackward0>)

In [92]:
key

tensor([[ 1.1200,  0.6358],
        [ 1.3327,  0.3562],
        [-2.7919,  0.1174],
        [-0.2720, -0.3710],
        [ 2.6427, -0.1428],
        [ 0.1643,  0.7206],
        [ 3.2843, -1.2901],
        [ 1.4715, -0.5095],
        [-1.4955,  0.2583]], grad_fn=<MmBackward0>)

In [87]:
query.shape, key.T.shape

(torch.Size([9, 2]), torch.Size([2, 9]))

In [78]:
qk = query @ key.T
qk

tensor([[ 0.7147,  0.4952, -0.2712, -0.3657,  0.2286,  0.6614, -0.7029, -0.2542,
          0.0281],
        [ 0.5402,  0.2854,  0.1733, -0.3246, -0.1923,  0.6393, -1.2323, -0.4910,
          0.2673],
        [-0.6135, -0.1585, -0.9010,  0.4582,  0.8980, -0.9855,  2.7042,  1.1139,
         -0.7615],
        [-0.4383, -0.4234,  0.6755,  0.1595, -0.6317, -0.2179, -0.5125, -0.2464,
          0.3140],
        [ 0.2870, -0.2596,  1.8405, -0.3949, -1.7897,  0.9840, -3.8946, -1.6422,
          1.2793],
        [ 0.8570,  0.9300, -1.7548, -0.2567,  1.6539,  0.2663,  1.8061,  0.8246,
         -0.8961],
        [-0.4261, -1.0537,  3.3866, -0.1922, -3.2488,  0.7940, -5.5566, -2.3962,
          2.0808],
        [-0.2294, -0.5512,  1.7550, -0.0948, -1.6832,  0.4024, -2.8651, -1.2361,
          1.0759],
        [-0.5723, -0.5025,  0.6677,  0.2356, -0.6179, -0.3636, -0.2719, -0.1524,
          0.2705]], grad_fn=<MmBackward0>)

In [99]:
math.sqrt(d_k)

1.4142135623730951

In [98]:
qk = qk // math.sqrt(d_k)

In [100]:
qk

tensor([[ 0.,  0., -1., -1.,  0.,  0., -1., -1.,  0.],
        [ 0.,  0.,  0., -1., -1.,  0., -1., -1.,  0.],
        [-1., -1., -1.,  0.,  0., -1.,  1.,  0., -1.],
        [-1., -1.,  0.,  0., -1., -1., -1., -1.,  0.],
        [ 0., -1.,  1., -1., -2.,  0., -3., -2.,  0.],
        [ 0.,  0., -2., -1.,  1.,  0.,  1.,  0., -1.],
        [-1., -1.,  2., -1., -3.,  0., -4., -2.,  1.],
        [-1., -1.,  1., -1., -2.,  0., -3., -1.,  0.],
        [-1., -1.,  0.,  0., -1., -1., -1., -1.,  0.]],
       grad_fn=<NotImplemented>)

In [107]:
F.softmax(qk,dim=-1)

tensor([[0.1545, 0.1545, 0.0568, 0.0568, 0.1545, 0.1545, 0.0568, 0.0568, 0.1545],
        [0.1545, 0.1545, 0.1545, 0.0568, 0.0568, 0.1545, 0.0568, 0.0568, 0.1545],
        [0.0487, 0.0487, 0.0487, 0.1323, 0.1323, 0.0487, 0.3597, 0.1323, 0.0487],
        [0.0706, 0.0706, 0.1920, 0.1920, 0.0706, 0.0706, 0.0706, 0.0706, 0.1920],
        [0.1476, 0.0543, 0.4013, 0.0543, 0.0200, 0.1476, 0.0073, 0.0200, 0.1476],
        [0.0970, 0.0970, 0.0131, 0.0357, 0.2637, 0.0970, 0.2637, 0.0970, 0.0357],
        [0.0296, 0.0296, 0.5952, 0.0296, 0.0040, 0.0806, 0.0015, 0.0109, 0.2190],
        [0.0577, 0.0577, 0.4264, 0.0577, 0.0212, 0.1569, 0.0078, 0.0577, 0.1569],
        [0.0706, 0.0706, 0.1920, 0.1920, 0.0706, 0.0706, 0.0706, 0.0706, 0.1920]],
       grad_fn=<SoftmaxBackward0>)

In [108]:
attn_score = F.softmax(qk,dim=-1)

Now we obtain attn score, So we want to calculate context vector.

In [114]:
attn_score.shape,value.shape

(torch.Size([9, 9]), torch.Size([9, 4]))

In [112]:
context_vector = attn_score @ value
context_vector

tensor([[ 0.2052, -0.3777, -2.7417, -0.9381],
        [-0.1067, -0.5855, -1.5464, -0.8251],
        [ 0.8949,  0.4038, -1.3771,  0.9435],
        [-0.3349, -0.8857,  0.3329, -0.0954],
        [-0.8072, -1.1449,  0.8933, -0.6657],
        [ 1.0210,  0.4577, -3.2941,  0.0933],
        [-1.4213, -1.8185,  3.4337, -0.2061],
        [-0.8752, -1.1931,  1.5479, -0.4144],
        [-0.3349, -0.8857,  0.3329, -0.0954]], grad_fn=<MmBackward0>)

In [113]:
context_vector.shape

torch.Size([9, 4])