In [57]:
import math
import torch
import numpy as np

In [49]:
d_output_embedding = 32
d_input_embedding = 32
n_sequence_words = 10
n_head = 8

n_words = 1000

embedding_layer = torch.nn.Linear(n_words, d_input_embedding)
qkv_projection_layer = torch.nn.Linear(d_input_embedding, 3 * d_output_embedding)
mask_matrix = torch.triu(torch.ones(n_sequence_words, n_sequence_words)  * float('-inf'), diagonal=1)

In [65]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
position = torch.arange(n_sequence_words).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_output_embedding, 2) * (-math.log(10000.0) / d_output_embedding))
pe = torch.zeros(n_sequence_words, d_output_embedding)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

In [66]:
def self_attention(X, mask = False):
    X += pe
    qkv = qkv_projection_layer(X)
    batch_size, n_sequence_words, _ = qkv.shape
    qkv = qkv.reshape(batch_size, n_sequence_words, n_head, -1) # (batch_size, n_sequence_words, n_head, 3 * d_output_embedding for q + k + v)
    qkv = qkv.permute(0, 2, 1, 3) # (batch_size, n_head, n_sequence_words, 3 * d_output_embedding for q + k + v) The calculation is per head

    q, k, v = qkv.chunk(3, dim=-1) # (batch_size, n_head, n_sequence_words, d_output_embedding) for q, k, v
    scale = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_output_embedding) # (batch_size, n_head, n_sequence_words, n_sequence_words)
    if mask:
        scale = scale + mask_matrix

    attention = torch.nn.functional.softmax(scale, dim=-1)
    return torch.matmul(attention, v), attention # (batch_size, n_head, n_sequence_words, d_output_embedding) for the output of the self-attention

In [67]:
X = torch.randn(8, n_sequence_words, d_input_embedding)

In [68]:
self_attention(X, True)[1][0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.4760, 0.5240, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3123, 0.2832, 0.4046, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2163, 0.2637, 0.2844, 0.2356, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1672, 0.2236, 0.2188, 0.1913, 0.1990, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1348, 0.1886, 0.2283, 0.1643, 0.1514, 0.1326, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1336, 0.1440, 0.1573, 0.1363, 0.1282, 0.1462, 0.1544, 0.0000, 0.0000,
         0.0000],
        [0.1212, 0.1312, 0.1512, 0.1356, 0.1025, 0.1240, 0.1341, 0.1001, 0.0000,
         0.0000],
        [0.0951, 0.1461, 0.1313, 0.1272, 0.1034, 0.0903, 0.0938, 0.1096, 0.1031,
         0.0000],
        [0.0705, 0.1555, 0.1275, 0.1087, 0.0791, 0.0793, 0.0853, 0.1099, 0.0888,
         0.0955]], grad_fn=<

In [64]:
pe

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  5.3317e-01,  8.4601e-01,  3.1098e-01,
          9.5042e-01,  1.7689e-01,  9.8423e-01,  9.9833e-02,  9.9500e-01,
          5.6204e-02,  9.9842e-01,  3.1618e-02,  9.9950e-01,  1.7782e-02,
          9.9984e-01,  9.9998e-03,  9.9995e-01,  5.6234e-03,  9.9998e-01,
          3.1623e-03,  9.9999e-01,  1.7783e-03,  1.0000e+00,  1.0000e-03,
          1.0000e+00,  5.6234e-04,  1.0000e+00,  3.1623e-04,  1.0000e+00,
          1.7783e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.02