In [2]:
import math
import torch
import numpy as np

In [3]:
d_output_embedding = 32
d_input_embedding = 32
n_sequence_words = 10
n_head = 8

n_words = 1000

embedding_layer = torch.nn.Linear(n_words, d_input_embedding)
qkv_projection_layer = torch.nn.Linear(d_input_embedding, 3 * d_output_embedding)
mask_matrix = torch.triu(torch.ones(n_sequence_words, n_sequence_words)  * float('-inf'), diagonal=1)

In [4]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# Positional encoding
position = torch.arange(n_sequence_words).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_output_embedding, 2) * (-math.log(10000.0) / d_output_embedding))
pe = torch.zeros(n_sequence_words, d_output_embedding)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

In [5]:
"""
    self attention function from the transformer model
    Q, K, V: (batch_size, n_head, n_sequence_words, d_output_embedding)
    mask: (n_sequence_words, n_sequence_words), None if no mask
    return: (batch_size, n_head, n_sequence_words, d_output_embedding) for the output of the self-attention
            (batch_size, n_head, n_sequence_words, n_sequence_words) for the attention matrix

    Example in matrix with batch size 2, n_head 3, n_sequence_words 2, d_output_embedding 3:
    Q (or K, V) = [
            // batch 1
            [
                // head 1
                [
                    // word 1
                    [1, 2, 3],
                    // word 2 
                    [4, 5, 6]
                ],
                // head 2
                [
                    [7, 8, 9],
                    [10, 11, 12]
                ],
            ], 
            // batch 2
            [
                [
                    [13, 14, 15],
                    [16, 17, 18]
                ],
                [
                    [19, 20, 21],
                    [22, 23, 24]
                ]
            ], 
            [
                [
                    [13, 14, 15], 
                    [16, 17, 18]
                ],
                [
                    [19, 20, 21],
                    [22, 23, 24]
                ]
            ]
        ]
"""
def self_attention(Q, K, V, mask = None):
    # Q, K, V: (batch_size, n_head, n_sequence_words, d_output_embedding)
    scale = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_output_embedding) # (batch_size, n_head, n_sequence_words, n_sequence_words)
    if mask is not None:
        scale = scale + mask
    attention = torch.nn.functional.softmax(scale, dim=-1)
    return torch.matmul(attention, V), attention # (batch_size, n_head, n_sequence_words, d_output_embedding) for the output of the self-attention


def Encoder(X, mask = None):
    # X: (batch_size, n_sequence_words, d_input_embedding)

    # Add positional encoding
    X += pe

    # Q, K, V projection
    qkv = qkv_projection_layer(X)
    batch_size, n_sequence_words, _ = qkv.shape
    qkv = qkv.reshape(batch_size, n_sequence_words, n_head, -1) # (batch_size, n_sequence_words, n_head, 3 * d_output_embedding for q + k + v)
    qkv = qkv.permute(0, 2, 1, 3) # (batch_size, n_head, n_sequence_words, 3 * d_output_embedding for q + k + v) The calculation is per head
    q, k, v = qkv.chunk(3, dim=-1) # (batch_size, n_head, n_sequence_words, d_output_embedding) for q, k, v
    
    return self_attention(q, k, v, mask)

In [6]:
X = torch.randn(8, n_sequence_words, d_input_embedding)

In [8]:
Encoder(X, None)[1][0][0]

tensor([[0.1167, 0.0865, 0.0755, 0.1049, 0.0823, 0.1172, 0.1027, 0.1184, 0.1077,
         0.0883],
        [0.0993, 0.1018, 0.1092, 0.1050, 0.1002, 0.0944, 0.1016, 0.0927, 0.0960,
         0.0998],
        [0.1075, 0.0850, 0.0703, 0.0971, 0.0916, 0.1137, 0.0983, 0.1220, 0.1154,
         0.0992],
        [0.1253, 0.0728, 0.0656, 0.0916, 0.0960, 0.1171, 0.0954, 0.1102, 0.1330,
         0.0931],
        [0.1804, 0.0719, 0.1017, 0.0876, 0.0892, 0.1233, 0.0967, 0.0708, 0.1198,
         0.0585],
        [0.1660, 0.0745, 0.1105, 0.0926, 0.0971, 0.1104, 0.0979, 0.0665, 0.1186,
         0.0660],
        [0.1782, 0.0637, 0.0916, 0.1001, 0.0918, 0.1099, 0.0993, 0.0680, 0.1306,
         0.0669],
        [0.1287, 0.0991, 0.1459, 0.0896, 0.1041, 0.1005, 0.0973, 0.0669, 0.0952,
         0.0727],
        [0.1688, 0.0673, 0.1163, 0.0951, 0.1086, 0.0957, 0.0961, 0.0543, 0.1274,
         0.0703],
        [0.1299, 0.0906, 0.1234, 0.0825, 0.1141, 0.1043, 0.0936, 0.0716, 0.1109,
         0.0793]], grad_fn=<

In [9]:
Encoder(X, mask_matrix)[1][0][0]

RuntimeError: Boolean value of Tensor with more than one value is ambiguous