### Scaled Dot Product Attn mechanism

In [1]:
import torch
import torch.nn.functional as F

- Create index of an input text

In [7]:
sentence = 'Life is too short to be afraid'

w_ix = {s:i for i, s in enumerate(sentence.split())}

In [8]:
w_ix

{'Life': 0, 'is': 1, 'too': 2, 'short': 3, 'to': 4, 'be': 5, 'afraid': 6}

In [10]:
sent_int = torch.tensor([w_ix[s] for s in sentence.replace(',', '').split()])
sent_int

tensor([0, 1, 2, 3, 4, 5, 6])

#### Create Embeddings for each word

In [16]:
torch.manual_seed(42)
embed = torch.nn.Embedding(7, 16)
embedded_sentence = embed(sent_int).detach()

print(embedded_sentence.shape)

torch.Size([7, 16])


Query, Key and Value vectors
- Assumed to be of same dim here

In [17]:
d = embedded_sentence.shape[1]
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)

Attn vectors

In [18]:
x_2 = embedded_sentence[2]
q_2 = U_query.matmul(x_2)
k_2 = U_key.matmul(x_2)
v_2 = U_value.matmul(x_2)

In [19]:
print(q_2.shape)

torch.Size([16])


In [20]:
keys = U_key.matmul(embedded_sentence.T).T
values = U_key.matmul(embedded_sentence.T).T

print(keys.shape)
print(values.shape)

torch.Size([7, 16])
torch.Size([7, 16])


Unnormalized attn weights

In [21]:
omg_2 = q_2.matmul(keys.T)
print(omg_2)

tensor([   7.2290, -163.6516,   85.0492,   32.2506,  -51.0580,    5.9121,
         -42.7785])


Compute Attn Scores
- Apply softmax
- Scale by 1/sqrt(dk) (dk = d here)

In [24]:
attn_weights_2 = F.softmax(omg_2/d**0.5, dim=0)
print(attn_weights_2)

tensor([3.5545e-09, 9.9459e-28, 1.0000e+00, 1.8512e-06, 1.6685e-15, 2.5573e-09,
        1.3222e-14])


Finally, compute Context Vector: atten weighted vector for input x[2]

In [25]:
context_2 = attn_weights_2.matmul(values)
print(context_2)

tensor([-2.6558, -2.1991, -1.7659, -2.7137, -2.1071, -2.3546, -3.4613, -2.4942,
        -3.1013, -3.5964, -1.6744, -2.1051, -2.5072, -2.5132, -1.1317, -3.7972])


#### Multi-Headed Attn

In [27]:
head = 3

multihead_U_query = torch.rand(head, d, d)
multihead_U_key = torch.rand(head, d, d)
multihead_U_value = torch.rand(head, d, d)

In [28]:
multihead_query_2 = multihead_U_query.matmul(x_2)
print(multihead_query_2)

tensor([[-1.5047, -1.7224, -3.2717, -1.0175, -3.1253, -3.4689, -1.0962, -2.1339,
         -1.5737, -3.9808, -3.3262, -2.3460, -1.2702,  0.7139, -0.4141, -3.3865],
        [-2.1503, -1.2775, -1.4216, -3.0825, -0.4842, -5.5525, -2.6342, -1.1775,
         -2.0682, -3.9901, -0.6912, -3.3454, -3.1575, -0.9453, -2.6569, -2.6200],
        [-2.2139, -0.1555, -1.1641, -1.1115, -2.9687, -1.7173, -2.8947, -2.9682,
         -2.1936, -3.2098, -2.8411, -0.4260, -3.3345, -2.7292, -3.0808, -3.1608]])


In [29]:
multihead_key_2 = multihead_U_key.matmul(x_2)
multihead_value_2 = multihead_U_value.matmul(x_2)

In [32]:
 embedded_sentence.T.shape

torch.Size([16, 7])

In [30]:
stacked_input = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_input.shape)

torch.Size([3, 16, 7])


In [33]:
multihead_keys = torch.bmm(multihead_U_key, stacked_input)
multihead_values = torch.bmm(multihead_U_value, stacked_input)
print(multihead_keys.shape)
print(multihead_values.shape)

torch.Size([3, 16, 7])
torch.Size([3, 16, 7])
