In [127]:
import torch
import math

In [128]:
d_model = 20
h = 4
d_k = int(d_model / h)
d_v = d_k
attention_scaling = 1/(math.sqrt(d_k))

In [129]:
embedding = torch.randn(50,d_model)
embedding

tensor([[-7.5277e-01,  1.9036e-01, -6.2000e-01,  1.4756e+00,  6.7161e-01,
         -2.3163e+00,  4.3996e-02,  1.3492e+00, -2.4367e-01, -1.3864e+00,
         -6.5649e-02, -2.0385e+00, -2.1538e+00, -1.2207e+00,  9.0934e-01,
          3.2480e-01, -1.3104e+00,  5.6545e-02,  5.1561e-01, -9.2656e-01],
        [-1.5554e+00, -1.6995e+00, -9.4175e-01,  8.2635e-01,  5.5439e-02,
          2.9529e-01,  9.8002e-01, -1.1846e+00, -1.9780e+00,  5.3823e-01,
         -1.0980e+00,  9.0322e-01,  1.9762e+00,  2.2644e-01,  1.5777e+00,
          1.3795e+00, -4.3789e-01,  1.1376e+00,  7.1678e-01, -1.2568e+00],
        [ 5.1045e-01, -2.2478e-01, -3.4924e-01,  4.4048e-01,  3.0018e-01,
         -1.0981e+00, -4.4090e-01,  1.4912e+00,  5.9819e-01, -1.4095e+00,
          1.9491e+00,  1.1545e+00,  6.0267e-01, -3.6708e-01,  5.6265e-01,
         -5.7065e-01, -8.8206e-02,  9.1255e-01,  6.0625e-01, -1.9273e-01],
        [ 1.7246e+00,  3.1544e-01,  6.4304e-01, -4.7046e-01, -4.0708e-01,
         -3.2228e-01, -2.3516e-01, 

In [130]:
W_O = torch.randn(h*d_v, d_model, requires_grad=True)
#tuples with weights in order Q, K, V for each head
head_weights = []
for i in range(h):
    W_Q = torch.randn(d_model, d_k, requires_grad=True)
    W_K = torch.randn(d_model, d_k, requires_grad=True)
    W_V = torch.randn(d_model, d_v, requires_grad=True)
    head_weights.append((W_Q, W_K, W_V))

In [131]:
def attention_mask(input_):
    mask = torch.tril(input_, diagonal=0)
    return mask.masked_fill(mask == 0, float('-inf'))

In [132]:
def attention(Q, K, V):
    y_1 = Q @ K.t()
    
    y_2 = attention_scaling * y_1
    
    y_3 = attention_mask(y_2)
    
    max_y_3 = torch.max(y_3, 0, keepdim=True)[0]
    exp_softmax = torch.exp(y_3-max_y_3)
    sum_softmax = torch.sum(exp_softmax, 0, keepdim=True)
    y_4 = exp_softmax/sum_softmax

    y_5 = y_4 @ V
    
    return y_5


In [133]:
def multi_head_attention(E):
    heads = []
    for weights in head_weights:
        Q_W = weights[0]
        K_W = weights[1]
        V_W = weights[2]

        Q = E @ Q_W
        K = E @ K_W
        V = E @ V_W
        heads.append(attention(Q, K, V))
    return torch.cat(heads, dim=1) @ W_O

In [134]:
multi_head_attention(embedding)

tensor([[-6.3130e-06, -3.2151e-06, -2.1697e-05,  1.1792e-05,  3.4103e-05,
          2.0517e-06,  1.0079e-05,  6.1983e-05, -2.8253e-05, -1.2564e-05,
          2.9015e-05,  7.3703e-06,  3.4342e-05,  5.6553e-06, -1.0107e-06,
          4.4664e-06, -2.8171e-05, -1.6720e-05,  9.5449e-07, -2.4227e-05],
        [ 1.5433e-02, -4.1019e-03,  6.0796e-03,  4.3010e-03,  1.0571e-02,
          2.2937e-02,  2.7945e-04, -2.5491e-02, -2.0477e-02, -6.7027e-03,
          1.1919e-05, -1.8642e-02, -1.0792e-03,  6.9744e-03, -5.2731e-03,
         -5.3128e-03, -5.8192e-03, -1.3912e-02, -2.0577e-02, -7.2959e-03],
        [-1.8094e-04, -8.0707e-05, -6.6539e-04,  3.6861e-04,  1.0335e-03,
          6.9005e-05,  3.0524e-04,  1.8749e-03, -8.7235e-04, -3.9182e-04,
          8.7943e-04,  1.9852e-04,  1.0326e-03,  1.7848e-04, -3.4392e-05,
          1.5313e-04, -8.5069e-04, -5.2291e-04,  1.0047e-05, -7.2980e-04],
        [-7.9737e-08, -3.8734e-08, -2.7818e-07,  1.5532e-07,  4.3289e-07,
          3.1093e-08,  1.2338e-07, 