In [1]:
import torch
import numpy as np





In [4]:
s = "Life is short, eat dessert first"
# Vocabulary: drop the comma, sort the words (case-sensitive, so the capitalised
# "Life" sorts before the lowercase words) and number them from 0.
sorted_words = sorted(s.replace(",", "").split())
dc = dict(zip(sorted_words, range(len(sorted_words))))
dc

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

In [6]:
# Encode the sentence, in its original word order, as a tensor of vocabulary ids.
ts = torch.tensor(list(map(dc.__getitem__, s.replace(",", "").split())))
ts

tensor([0, 4, 5, 2, 1, 3])

In [16]:
vocab_size = 50000  # table size; only the ids stored in dc (0-5) are looked up below
torch.manual_seed(123)  # seed before building the layer so its random weights are reproducible

embed = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=3)
# detach(): use the looked-up vectors as fixed inputs, outside the autograd graph.
embedded_sentence = embed(ts).detach()

print(embedded_sentence, embedded_sentence.shape)


tensor([[ 0.3374, -0.1778, -0.1690],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096],
        [ 1.2753, -0.2010, -0.1606],
        [ 0.9178,  1.5810,  1.3010],
        [-0.4015,  0.9666, -1.1481]]) torch.Size([6, 3])


#### Self-attention mechanism



In [28]:
# Learnable projection matrices, drawn uniformly from [0, 1) with torch.rand.
# dq must equal dk: queries are dot-multiplied with keys later on.
d = embedded_sentence.size(1)  # embedding width
dk, dq, dv = 2, 2, 4

# Created in the order Wq, Wk, Wv, which fixes which random draws each one receives.
Wq, Wk, Wv = (torch.nn.Parameter(torch.rand(d, width)) for width in (dq, dk, dv))
Wq

Parameter containing:
tensor([[0.7577, 0.4536],
        [0.4130, 0.5585],
        [0.1170, 0.5578]], requires_grad=True)

##### Generating the query, key and value vectors for element 2 (index 1)

In [35]:
# Second token of the sentence (0-based index 1, the word "is").
x_2 = embedded_sentence[1]
q_2 = x_2 @ Wq  # query, shape (dq,)
k_2 = x_2 @ Wk  # key,   shape (dk,)
v_2 = x_2 @ Wv  # value, shape (dv,)
print(q_2, q_2.shape)
print(k_2, k_2.shape)
print(v_2, v_2.shape)

tensor([-0.8175, -0.6962], grad_fn=<SqueezeBackward4>) torch.Size([2])
tensor([-1.2935, -1.0338], grad_fn=<SqueezeBackward4>) torch.Size([2])
tensor([-1.2396, -0.0786, -0.9770, -0.7058], grad_fn=<SqueezeBackward4>) torch.Size([4])


In [37]:
# Project every token at once: row i of q / k / v belongs to token i.
q, k, v = (torch.matmul(embedded_sentence, W) for W in (Wq, Wk, Wv))

print(q, q.shape)
print(k, k.shape)
print(v, v.shape)

tensor([[ 0.1624, -0.0405],
        [-0.8175, -0.6962],
        [-2.6408, -2.5129],
        [ 0.8645,  0.3767],
        [ 1.5005,  2.0251],
        [-0.0393, -0.2827]], grad_fn=<MmBackward0>) torch.Size([6, 2])
tensor([[-0.0047,  0.1438],
        [-1.2935, -1.0338],
        [-3.5769, -3.5702],
        [ 0.6223,  1.0003],
        [ 2.4583,  2.2977],
        [-1.0833, -0.0429]], grad_fn=<MmBackward0>) torch.Size([6, 2])
tensor([[ 0.1304, -0.0952,  0.1261,  0.0945],
        [-1.2396, -0.0786, -0.9770, -0.7058],
        [-3.9806, -1.5924, -2.8134, -2.7060],
        [ 1.0345,  0.1212,  0.7981,  0.7339],
        [ 2.5606,  1.8079,  1.6247,  1.8530],
        [-0.3506,  0.6221, -0.4360,  0.1402]], grad_fn=<MmBackward0>) torch.Size([6, 4])


Unnormalised attention weights $\omega$ for $q_2$: the dot product of the query $q_2$ with each key vector.


In [42]:
# Unnormalised score of query 2 against the key of the first token.
# Both vectors have width 2 (dq == dk), so the dot product is defined.
print(q_2.shape)
print(k[0].shape)
torch.dot(q_2, k[0])

torch.Size([2])
torch.Size([2])


tensor(-0.0962, grad_fn=<DotBackward0>)

In [48]:
# Scores of query 2 against every key in one product: one entry per token.
Omega_2 = torch.matmul(q_2, k.transpose(0, 1))
Omega_2.shape

torch.Size([6])

Normalise the attention weights: scale the scores by $1/\sqrt{d_k}$ and apply softmax so they sum to 1.

In [50]:
import torch.nn.functional as F
# Scaled dot-product attention: divide the scores by sqrt(dk), then apply softmax
# over the tokens so the weights are positive and sum to 1.
# dim is passed explicitly; omitting it triggers PyTorch's implicit-dimension
# deprecation warning. For this 1-D tensor, dim=-1 is the only axis.
attention_weights_2 = F.softmax(Omega_2 / dk**0.5, dim=-1)
attention_weights_2

  attention_weights_2 = F.softmax(Omega_2/(dk**0.5))


tensor([0.0177, 0.0667, 0.8698, 0.0081, 0.0015, 0.0363],
       grad_fn=<SoftmaxBackward0>)

In [57]:
# One weight per token must line up with the rows of v for the weighted sum below.
print(v.shape, attention_weights_2.shape, sep="\n")

torch.Size([6, 4])
torch.Size([6])


In [55]:
context_vector_2 = torch.matmul(attention_weights_2, v)  # attention-weighted sum of the value rows

In [56]:
# Display the context vector for token 2: one entry per value dimension (dv = 4).
context_vector_2

tensor([-3.5431, -1.3658, -2.5169, -2.3852], grad_fn=<SqueezeBackward4>)

Wrapping the steps above in a `SelfAttention` class that processes all tokens at once

In [65]:
import torch.nn as nn
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Args:
        d: width of each input token embedding.
        dq: width of the query projection. Must equal ``dk``, because queries
            are dot-multiplied with keys.
        dk: width of the key projection; also sets the ``1 / sqrt(dk)`` scaling.
        dv: width of the value projection, i.e. of each output context vector.

    The projection weights are drawn uniformly from [0, 1) with ``torch.rand``,
    mirroring the step-by-step cells above.
    """

    def __init__(self, d, dq, dk, dv):
        super().__init__()
        self.dk = dk
        self.dv = dv
        # Creation order (Wq, Wk, Wv) fixes which random draws each matrix gets.
        self.Wq = nn.Parameter(torch.rand(d, dq))
        self.Wk = nn.Parameter(torch.rand(d, dk))
        self.Wv = nn.Parameter(torch.rand(d, dv))

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: token embeddings of shape ``(n, d)``, or batched ``(..., n, d)``.

        Returns:
            Tensor of shape ``(n, dv)``, or ``(..., n, dv)`` for batched input.
        """
        q = x @ self.Wq  # (..., n, dq)
        k = x @ self.Wk  # (..., n, dk)
        v = x @ self.Wv  # (..., n, dv)

        # omega = unnormalised attention scores, shape (..., n, n).
        # transpose(-2, -1) swaps only the last two axes. k.T would reverse every
        # axis and break for batched (3-D or higher) input.
        omega = q @ k.transpose(-2, -1)

        # Scale, then normalise over the keys (last axis) so each row sums to 1.
        attention_weights = torch.softmax(omega / self.dk**0.5, dim=-1)
        return attention_weights @ v



        


In [66]:
torch.manual_seed(123)  # reseed so the module's random projection weights are reproducible
d = embedded_sentence.shape[1]  # embedding width, taken from the data instead of hard-coding 3
dk, dq, dv = 2, 2, 4
# Keyword arguments prevent mixing up dq and dk: the signature is (d, dq, dk, dv).
a = SelfAttention(d=d, dq=dq, dk=dk, dv=dv)
context_vectors = a(embedded_sentence)  # one context vector (row) per token
print(context_vectors)

tensor([[-0.3556, -0.1204, -0.2910, -0.3997],
        [-0.8818, -0.9456, -1.0524, -1.5258],
        [-1.4612, -1.9542, -2.0118, -2.9217],
        [ 0.1046,  0.5374,  0.2862,  0.4735],
        [ 1.3173,  2.1931,  1.6684,  2.6063],
        [-0.6171, -0.5215, -0.6524, -0.9389]], grad_fn=<MmBackward0>)
