<a href="https://colab.research.google.com/github/arjunravi26/transformer-implementation/blob/main/attn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

Data Preparation

In [3]:
sentence = 'The quick brown fox jumps over a lazy dog'
dc = {s: i for i, s in enumerate(sorted(sentence.split()))}
print(dc)

{'The': 0, 'a': 1, 'brown': 2, 'dog': 3, 'fox': 4, 'jumps': 5, 'lazy': 6, 'over': 7, 'quick': 8}


In [4]:
r = [dc[i] for i in sentence.split()]
sentence_int = torch.tensor(r)
print(sentence_int)

tensor([0, 8, 2, 4, 5, 7, 1, 6, 3])


Embedding

In [5]:
vocab_size = 50000
torch.manual_seed(42)
embed = nn.Embedding(vocab_size,3)
embed

Embedding(50000, 3)

In [6]:
embeddings = embed(sentence_int).detach()
embeddings.shape

torch.Size([9, 3])

In [7]:
embeddings

tensor([[ 1.9269,  1.4873,  0.9007],
        [ 1.2791,  1.2964,  0.6105],
        [-0.0431, -1.6047, -0.7521],
        [-0.7279, -0.5594, -0.7688],
        [ 0.7624,  1.6423, -0.1596],
        [ 1.0783,  0.8008,  1.6806],
        [-2.1055,  0.6784, -1.2345],
        [-0.4974,  0.4396, -0.7581],
        [ 1.6487, -0.3925, -1.4036]])

In [8]:
weights = embed.weight

In [9]:
sentence_int

tensor([0, 8, 2, 4, 5, 7, 1, 6, 3])

In [10]:
weights[1]

tensor([-2.1055,  0.6784, -1.2345], grad_fn=<SelectBackward0>)

In [11]:
torch.embedding(weights,torch.tensor([0,8]))

tensor([[1.9269, 1.4873, 0.9007],
        [1.2791, 1.2964, 0.6105]], grad_fn=<EmbeddingBackward0>)

Self attention mechanism

In [12]:
embeddings.shape

torch.Size([9, 3])

In [13]:
torch.manual_seed(42)
d = embeddings.shape[1]
d

3

Set dimension for query, key and value

In [14]:
d_q, d_k, d_v = 2, 2, 4

Initialize weights for query, key and value

In [15]:
W_query = nn.Parameter(data=torch.randn(size=(d,d_q)))
W_key = nn.Parameter(data=torch.randn(size=(d,d_k)))
W_value = nn.Parameter(data=torch.randn(size=(d,d_v)))

In [16]:
W_query, W_key, W_value

(Parameter containing:
 tensor([[ 0.3367,  0.1288],
         [ 0.2345,  0.2303],
         [-1.1229, -0.1863]], requires_grad=True),
 Parameter containing:
 tensor([[ 2.2082, -0.6380],
         [ 0.4617,  0.2674],
         [ 0.5349,  0.8094]], requires_grad=True),
 Parameter containing:
 tensor([[ 1.1103, -1.6898, -0.9890,  0.9580],
         [ 1.3221,  0.8172, -0.7658, -0.7506],
         [ 1.3525,  0.6863, -0.3278,  0.7950]], requires_grad=True))

In [17]:
embeddings

tensor([[ 1.9269,  1.4873,  0.9007],
        [ 1.2791,  1.2964,  0.6105],
        [-0.0431, -1.6047, -0.7521],
        [-0.7279, -0.5594, -0.7688],
        [ 0.7624,  1.6423, -0.1596],
        [ 1.0783,  0.8008,  1.6806],
        [-2.1055,  0.6784, -1.2345],
        [-0.4974,  0.4396, -0.7581],
        [ 1.6487, -0.3925, -1.4036]])

In [18]:
query = embeddings @ W_query
key = embeddings @ W_key
value = embeddings @ W_value

In [19]:
query

tensor([[-0.0139,  0.4229],
        [ 0.0492,  0.3496],
        [ 0.4538, -0.2350],
        [ 0.4871, -0.0794],
        [ 0.8210,  0.5062],
        [-1.3363,  0.0102],
        [ 0.8364,  0.1151],
        [ 0.7869,  0.1784],
        [ 2.0391,  0.3835]], grad_fn=<MmBackward0>)

In [20]:
key

tensor([[ 5.4234, -0.1027],
        [ 3.7496,  0.0246],
        [-1.2382, -1.0103],
        [-2.2768, -0.3074],
        [ 2.3565, -0.1765],
        [ 3.6498,  0.8864],
        [-4.9966,  0.5255],
        [-1.3009, -0.1787],
        [ 2.7087, -2.2928]], grad_fn=<MmBackward0>)

In [21]:
value

tensor([[ 5.3241, -1.4225, -3.3399,  1.4456],
        [ 3.9599, -0.6831, -2.4579,  0.7375],
        [-3.1867, -1.7547,  1.5180,  0.5653],
        [-2.5877,  0.2451,  1.4003, -0.8886],
        [ 2.8020, -0.0558, -1.9595, -0.6292],
        [ 4.5291, -0.0143, -2.2305,  1.7679],
        [-3.1106,  3.2650,  1.9673, -3.5077],
        [-0.9965,  0.6794,  0.4037, -1.4091],
        [-0.5868, -4.0701, -0.8699,  0.7582]], grad_fn=<MmBackward0>)

Calculating Attention Scores

In [22]:
query.shape

torch.Size([9, 2])

In [23]:
key.shape

torch.Size([9, 2])

In [24]:
query

tensor([[-0.0139,  0.4229],
        [ 0.0492,  0.3496],
        [ 0.4538, -0.2350],
        [ 0.4871, -0.0794],
        [ 0.8210,  0.5062],
        [-1.3363,  0.0102],
        [ 0.8364,  0.1151],
        [ 0.7869,  0.1784],
        [ 2.0391,  0.3835]], grad_fn=<MmBackward0>)

In [25]:
key

tensor([[ 5.4234, -0.1027],
        [ 3.7496,  0.0246],
        [-1.2382, -1.0103],
        [-2.2768, -0.3074],
        [ 2.3565, -0.1765],
        [ 3.6498,  0.8864],
        [-4.9966,  0.5255],
        [-1.3009, -0.1787],
        [ 2.7087, -2.2928]], grad_fn=<MmBackward0>)

In [26]:
query.shape, key.T.shape

(torch.Size([9, 2]), torch.Size([2, 9]))

In [27]:
qk = query @ key.T
qk

tensor([[-0.1188, -0.0417, -0.4101, -0.0984, -0.1074,  0.3242,  0.2917, -0.0575,
         -1.0074],
        [ 0.2307,  0.1930, -0.4141, -0.2194,  0.0541,  0.4893, -0.0619, -0.1265,
         -0.6685],
        [ 2.4853,  1.6958, -0.3245, -0.9610,  1.1109,  1.4480, -2.3910, -0.5484,
          1.7681],
        [ 2.6497,  1.8243, -0.5229, -1.0846,  1.1617,  1.7073, -2.4753, -0.6195,
          1.5013],
        [ 4.4005,  3.0908, -1.5280, -2.0249,  1.8452,  3.4451, -3.8360, -1.1585,
          1.0631],
        [-7.2483, -5.0103,  1.6443,  3.0393, -3.1507, -4.8681,  6.6822,  1.7366,
         -3.6430],
        [ 4.5242,  3.1389, -1.1519, -1.9397,  1.9506,  3.1546, -4.1185, -1.1086,
          2.0016],
        [ 4.2492,  2.9548, -1.1546, -1.8464,  1.8227,  3.0301, -3.8379, -1.0556,
          1.7223],
        [11.0197,  7.6554, -2.9124, -4.7607,  4.7374,  7.7824, -9.9872, -2.7213,
          4.6442]], grad_fn=<MmBackward0>)

In [28]:
math.sqrt(d_k)

1.4142135623730951

In [29]:
qk = qk // math.sqrt(d_k)

In [30]:
qk

tensor([[-1., -1., -1., -1., -1.,  0.,  0., -1., -1.],
        [ 0.,  0., -1., -1.,  0.,  0., -1., -1., -1.],
        [ 1.,  1., -1., -1.,  0.,  1., -2., -1.,  1.],
        [ 1.,  1., -1., -1.,  0.,  1., -2., -1.,  1.],
        [ 3.,  2., -2., -2.,  1.,  2., -3., -1.,  0.],
        [-6., -4.,  1.,  2., -3., -4.,  4.,  1., -3.],
        [ 3.,  2., -1., -2.,  1.,  2., -3., -1.,  1.],
        [ 3.,  2., -1., -2.,  1.,  2., -3., -1.,  1.],
        [ 7.,  5., -3., -4.,  3.,  5., -8., -2.,  3.]],
       grad_fn=<NotImplemented>)

In [31]:
F.softmax(qk,dim=-1)

tensor([[8.0408e-02, 8.0408e-02, 8.0408e-02, 8.0408e-02, 8.0408e-02, 2.1857e-01,
         2.1857e-01, 8.0408e-02, 8.0408e-02],
        [1.7125e-01, 1.7125e-01, 6.3000e-02, 6.3000e-02, 1.7125e-01, 1.7125e-01,
         6.3000e-02, 6.3000e-02, 6.3000e-02],
        [2.0731e-01, 2.0731e-01, 2.8056e-02, 2.8056e-02, 7.6265e-02, 2.0731e-01,
         1.0321e-02, 2.8056e-02, 2.0731e-01],
        [2.0731e-01, 2.0731e-01, 2.8056e-02, 2.8056e-02, 7.6265e-02, 2.0731e-01,
         1.0321e-02, 2.8056e-02, 2.0731e-01],
        [5.1147e-01, 1.8816e-01, 3.4463e-03, 3.4463e-03, 6.9220e-02, 1.8816e-01,
         1.2678e-03, 9.3679e-03, 2.5465e-02],
        [3.6688e-05, 2.7109e-04, 4.0234e-02, 1.0937e-01, 7.3690e-04, 2.7109e-04,
         8.0811e-01, 4.0234e-02, 7.3690e-04],
        [4.8726e-01, 1.7925e-01, 8.9245e-03, 3.2832e-03, 6.5944e-02, 1.7925e-01,
         1.2078e-03, 8.9245e-03, 6.5944e-02],
        [4.8726e-01, 1.7925e-01, 8.9245e-03, 3.2832e-03, 6.5944e-02, 1.7925e-01,
         1.2078e-03, 8.9245e-0

In [32]:
attn_score = F.softmax(qk,dim=-1)

Now we obtain attn score, So we want to calculate context vector.

In [33]:
attn_score.shape,value.shape

(torch.Size([9, 9]), torch.Size([9, 4]))

In [34]:
context_vector = attn_score @ value
context_vector

tensor([[ 0.6903,  0.1427, -0.4841, -0.3336],
        [ 2.1859, -0.4756, -1.4320,  0.2865],
        [ 2.7336, -1.2771, -1.8806,  0.8435],
        [ 2.7336, -1.2771, -1.8806,  0.8435],
        [ 4.4662, -0.9610, -2.7319,  1.1678],
        [-2.9609,  2.6188,  1.8168, -2.9649],
        [ 4.2124, -1.0951, -2.6303,  1.1473],
        [ 4.2124, -1.0951, -2.6303,  1.1473],
        [ 4.9815, -1.2180, -3.0792,  1.3666]], grad_fn=<MmBackward0>)

In [35]:
context_vector.shape

torch.Size([9, 4])

Self Attention

In [36]:
class SelfAttention(nn.Module):
  def __init__(self,d_model,d_key,d_value):
    super(SelfAttention,self).__init__()
    self.d_model = d_model
    self.d_key = d_key
    self.d_value = d_value
    self.W_key = nn.Linear(d_model,d_key,bias=False)
    self.W_query = nn.Linear(d_model,d_key,bias=False)
    self.W_value = nn.Linear(d_model,d_value,bias=False)

  def forward(self,x):
    key = self.W_key(x)
    query = self.W_query(x)
    value = self.W_value(x)
    attn_weights = query @ key.transpose(-2,-1)
    attn_scores = attn_weights / math.sqrt(self.d_key)
    attn_scores = F.softmax(attn_scores,dim=-1)
    return attn_scores @ value


In [40]:
self_attn = SelfAttention(d_model=3,d_key=2,d_value=4)
attn = self_attn(x=embeddings)
attn

tensor([[-0.7420,  0.1583, -0.5563, -0.8465],
        [-0.6801,  0.1724, -0.4873, -0.7122],
        [-0.2891,  0.2603, -0.0530,  0.1516],
        [-0.2930,  0.2554, -0.0605,  0.1223],
        [-0.6066,  0.1796, -0.4133, -0.5800],
        [-0.7146,  0.1792, -0.5138, -0.7475],
        [-0.2163,  0.2682,  0.0214,  0.2617],
        [-0.3786,  0.2363, -0.1553, -0.0706],
        [-0.4685,  0.2278, -0.2457, -0.2298]], grad_fn=<MmBackward0>)

Multihead attn

In [88]:
class MultiHeadAttention(nn.Module):
  def __init__(self,d_model,num_heads):
    super(MultiHeadAttention,self).__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    self.d_model = d_model
    self.d_key = d_model // num_heads
    self.num_heads = num_heads

    self.W_key = nn.Linear(d_model,d_model,bias=False)
    self.W_value = nn.Linear(d_model,d_model,bias=False)
    self.W_query = nn.Linear(d_model,d_model,bias=False)

    self.W_output = nn.Linear(d_model,d_model,bias=False)

  def forward(self,x):
    B,S,_ = x.shape
    key = self.W_key(x)
    query = self.W_query(x)
    value = self.W_value(x)

    Q = query.view(B,S,self.num_heads,self.d_key).transpose(2,1)
    K = key.view(B,S,self.num_heads,self.d_key).transpose(1,2)
    V = value.view(B,S,self.num_heads,self.d_key).transpose(2,1)

    scores = Q @ K.transpose(-2,-1)
    scores = scores / math.sqrt(self.d_key)
    attn = F.softmax(scores,dim=-1)

    out = attn @ V
    out = out.transpose(2,1).contiguous()
    out = out.view(B,S,self.d_model)

    return self.W_output(out)



In [53]:
sentence

'The quick brown fox jumps over a lazy dog'

In [57]:
vocab_size = 1000
embed_dim = 6
embed = nn.Embedding(vocab_size,embed_dim)
embed

Embedding(1000, 6)

In [61]:
dc = {w:i for i,w in enumerate(sentence.split())}
dc

{'The': 0,
 'quick': 1,
 'brown': 2,
 'fox': 3,
 'jumps': 4,
 'over': 5,
 'a': 6,
 'lazy': 7,
 'dog': 8}

In [63]:
sentence_int = [dc[v] for v in dc]
sentence_int

[0, 1, 2, 3, 4, 5, 6, 7, 8]

In [72]:
embeddings = embed(torch.tensor(sentence_int))
embeddings.shape

torch.Size([9, 6])

In [85]:
batched_embeddings = embeddings.reshape(1,embeddings.shape[0],embeddings.shape[1])
batched_embeddings

tensor([[[-1.3712,  1.0005,  0.3395,  1.2227,  0.1482, -1.1230],
         [-0.9876, -0.7730,  0.7633,  0.3511, -0.1280,  1.2385],
         [-0.3839, -0.1225,  0.1394,  0.1799, -0.3362,  0.9838],
         [ 1.1671,  0.3482, -1.5765,  0.4396,  0.7627, -1.3002],
         [-0.7500,  0.2921,  1.6024, -0.6984,  0.3449, -0.0996],
         [ 0.0329, -1.4311, -0.0211,  0.3831, -0.5715, -1.0143],
         [ 0.4313, -0.9067,  0.3586, -0.7071, -0.1030,  1.5218],
         [-0.6902,  0.3947, -0.9836,  0.9162, -0.4780,  0.0592],
         [ 0.6442, -0.6631,  0.5682,  0.2775, -0.1973, -1.1960]]],
       grad_fn=<ViewBackward0>)

In [74]:
batched_embeddings.shape

torch.Size([1, 9, 6])

In [90]:
mult_head_attn = MultiHeadAttention(d_model=embed_dim,num_heads=1)
attn = mult_head_attn(batched_embeddings)
attn.shape

torch.Size([1, 9, 6])

In [91]:
mult_head_attn = MultiHeadAttention(d_model=embed_dim,num_heads=2)
attn = mult_head_attn(batched_embeddings)
attn.shape

torch.Size([1, 9, 6])