In [76]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
    

In [77]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor of shape (..., seq_len_q, d_k).
        k: Key tensor of shape (..., seq_len_k, d_k).
        v: Value tensor of shape (..., seq_len_k, d_v).
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k).
            Positions where ``mask == 0`` are excluded from the attention.

    Returns:
        A tuple ``(values, attention)`` where ``attention`` holds the softmax
        weights of shape (..., seq_len_q, seq_len_k) and ``values`` is the
        attention-weighted sum of ``v`` with shape (..., seq_len_q, d_v).
    """
    d_k = q.size(-1)
    # Dividing by sqrt(d_k) keeps the logits' scale independent of the key size.
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the most negative finite value of the logits' dtype instead
        # of a hard-coded -9e15, which cannot be represented in float16. A finite
        # value (rather than -inf) keeps fully masked rows uniform instead of NaN.
        fill_value = torch.finfo(attn_logits.dtype).min
        attn_logits = attn_logits.masked_fill(mask == 0, fill_value)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

### Simulate multi-head Q, K, V

In [78]:
# Toy dimensions, kept small so every tensor can be printed in full.
batch_size = 2
head = 4
seq_len = 3
d_k = 2
shape = (batch_size, head, seq_len, d_k)

# Draw the three tensors in the order q, k, v from the standard normal distribution.
q, k, v = (torch.randn(shape) for _ in range(3))

q.shape, k.shape, v.shape

(torch.Size([2, 4, 3, 2]), torch.Size([2, 4, 3, 2]), torch.Size([2, 4, 3, 2]))

In [79]:
# Model width: the number of heads times the per-head feature size d_k.
d_model = head * d_k
d_model

8

In [80]:
# First batch element of q: one (seq_len, d_k) matrix per head.
q[0]

tensor([[[ 1.4516, -0.7912],
         [ 0.6224, -0.6406],
         [ 0.9762, -0.2074]],

        [[-0.4497,  0.0510],
         [-0.4176,  0.2446],
         [ 0.9037,  1.4893]],

        [[ 0.2951,  0.4999],
         [-0.5982,  1.5882],
         [-0.3139, -2.0337]],

        [[ 0.4014,  0.5351],
         [ 0.4024,  0.2968],
         [ 0.4885,  0.8046]]])

In [81]:
# Flatten the (head, seq_len, d_k) axes of q into rows of width d_model.
# NOTE(review): view() only regroups elements in memory order, so each row mixes
# several sequence positions and heads (compare the first row with q[0]). If a row
# is meant to be one position with all heads concatenated, use
# q.transpose(1, 2).reshape(batch_size, -1, d_model) instead.
q.view(batch_size, -1, d_model)

tensor([[[ 1.4516, -0.7912,  0.6224, -0.6406,  0.9762, -0.2074, -0.4497,
           0.0510],
         [-0.4176,  0.2446,  0.9037,  1.4893,  0.2951,  0.4999, -0.5982,
           1.5882],
         [-0.3139, -2.0337,  0.4014,  0.5351,  0.4024,  0.2968,  0.4885,
           0.8046]],

        [[-1.7663, -1.0167, -0.4292,  1.1023,  1.2906, -0.8918, -0.4826,
           1.1973],
         [ 0.3470, -0.4257,  1.1276,  0.3564, -0.0685,  0.6590,  1.5181,
          -1.5887],
         [-0.2452,  1.7298, -0.7800,  1.6039,  0.5995,  1.1992,  1.0092,
           0.1009]]])

In [82]:
# A single square projection (d_model -> d_model); the last line displays its repr.
linear_layer = nn.Linear(d_model, d_model)

linear_layer

Linear(in_features=8, out_features=8, bias=True)

### After Linear Projection

In [83]:
# Flatten each tensor to (batch_size, rows, d_model) and pass it through the one
# shared linear_layer, so the q, k and v projections all use the same weights.
flat_shape = (batch_size, -1, d_model)
w_q, w_k, w_v = (linear_layer(t.view(flat_shape)) for t in (q, k, v))

w_q.shape, w_k.shape, w_v.shape

(torch.Size([2, 3, 8]), torch.Size([2, 3, 8]), torch.Size([2, 3, 8]))

In [84]:
# Display the projected q tensor (the output of linear_layer).
w_q

tensor([[[ 0.5124, -0.1605,  0.7902, -0.0840,  0.2493, -0.1333, -0.6790,
           0.1412],
         [ 0.5485,  0.6827, -0.6042,  1.0182,  0.6211, -0.8116, -0.2984,
           0.3787],
         [ 0.3681,  0.7303, -0.8663, -0.0063,  0.5300, -0.1834, -0.8838,
          -0.2012]],

        [[ 0.7192,  1.5630, -0.8405,  1.5888,  0.0951, -0.3763, -0.1109,
           0.6873],
         [-1.1264,  0.1490,  0.2615, -1.0148, -0.3084,  0.6448, -0.9948,
          -1.0468],
         [-0.3558,  0.9942, -0.1582,  0.8780, -0.4400, -0.0886, -0.1796,
          -0.8628]]], grad_fn=<ViewBackward0>)

### Reshape into heads

In [85]:
# Split the last axis (d_model) into (head, d_k), then swap axes 1 and 2 so the
# head axis comes before the sequence axis: (batch_size, head, rows, d_k).
w_q_r, w_k_r, w_v_r = (
    t.view(batch_size, -1, head, d_k).transpose(1, 2) for t in (w_q, w_k, w_v)
)

w_q_r.shape, w_k_r.shape, w_v_r.shape

(torch.Size([2, 4, 3, 2]), torch.Size([2, 4, 3, 2]), torch.Size([2, 4, 3, 2]))

In [86]:
# First batch element after the head split: each head holds one d_k-wide slice of every row of w_q[0].
w_q_r[0]

tensor([[[ 0.5124, -0.1605],
         [ 0.5485,  0.6827],
         [ 0.3681,  0.7303]],

        [[ 0.7902, -0.0840],
         [-0.6042,  1.0182],
         [-0.8663, -0.0063]],

        [[ 0.2493, -0.1333],
         [ 0.6211, -0.8116],
         [ 0.5300, -0.1834]],

        [[-0.6790,  0.1412],
         [-0.2984,  0.3787],
         [-0.8838, -0.2012]]], grad_fn=<SelectBackward0>)

In [87]:
# Apply attention to the per-head tensors; the leading (batch, head) axes are
# treated as batch dimensions by matmul. No mask is passed, so every query row
# attends to every key row.
values, attention = scaled_dot_product(w_q_r, w_k_r, w_v_r)

values.shape, attention.shape

(torch.Size([2, 4, 3, 2]), torch.Size([2, 4, 3, 3]))

In [88]:
# Attention output for the first batch element, one matrix per head.
values[0]

tensor([[[-0.9383,  0.6550],
         [-0.8682,  0.6854],
         [-0.8406,  0.6796]],

        [[-0.5437,  0.2594],
         [-0.5576,  0.1968],
         [-0.3841,  0.1295]],

        [[-0.4620,  0.7077],
         [-0.2423,  0.5695],
         [-0.3570,  0.6486]],

        [[-0.1075, -0.6816],
         [-0.0758, -0.6241],
         [-0.1374, -0.7357]]], grad_fn=<SelectBackward0>)

In [89]:
# Attention weights for the first batch element; each row is a softmax over the
# last axis, so it sums to 1.
attention[0]

tensor([[[0.3172, 0.3968, 0.2861],
         [0.2640, 0.4031, 0.3329],
         [0.2694, 0.3795, 0.3511]],

        [[0.2294, 0.3250, 0.4457],
         [0.3024, 0.4633, 0.2343],
         [0.4749, 0.3017, 0.2234]],

        [[0.2816, 0.3492, 0.3692],
         [0.1932, 0.3555, 0.4513],
         [0.2342, 0.3642, 0.4016]],

        [[0.3537, 0.3123, 0.3340],
         [0.2909, 0.4228, 0.2863],
         [0.4128, 0.2085, 0.3787]]], grad_fn=<SelectBackward0>)