In [2]:
import math
import torch
import numpy as np

In [3]:
# Sizes shared by the cells below.
n_words = 1000          # input width of `embedding_layer`
n_sequence_words = 10   # side length of the square `mask_matrix`
d_input_embedding = 32
d_output_embedding = 32
n_head = 8

# The two layers are built in the same order as before so that they draw the
# same initial weights from the global RNG.
embedding_layer = torch.nn.Linear(in_features=n_words, out_features=d_input_embedding)
qkv_projection_layer = torch.nn.Linear(
    in_features=d_input_embedding,
    out_features=3 * d_output_embedding,
)

# Additive mask: -inf strictly above the diagonal, 0 on and below it.
mask_matrix = torch.full(
    (n_sequence_words, n_sequence_words), float('-inf')
).triu(diagonal=1)

In [4]:
# Sinusoidal positional encoding, adapted from
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# `pe` has one row per position and `d_output_embedding` columns:
# even columns hold sin(position * div_term), odd columns hold cos(position * div_term).
position = torch.arange(n_sequence_words)[:, None]  # column vector, shape (n_sequence_words, 1)
div_term = torch.exp(
    torch.arange(0, d_output_embedding, 2) * (-math.log(10000.0) / d_output_embedding)
)  # one frequency per (sin, cos) column pair
pe = torch.zeros(n_sequence_words, d_output_embedding)
pe[:, ::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

In [10]:
"""
    self attention function from the transformer model
    Q, K, V: (batch_size, n_head, n_sequence_words, d_output_embedding)
    mask: (n_sequence_words, n_sequence_words), None if no mask
    return: (batch_size, n_head, n_sequence_words, d_output_embedding) for the output of the self-attention
            (batch_size, n_head, n_sequence_words, n_sequence_words) for the attention matrix

    Example in matrix with batch size 3, n_head 2, n_sequence_words 2, d_output_embedding 3:
    Q (or K, V) = [
            // batch 1
            [
                // head 1
                [
                    // word 1
                    [1, 2, 3],
                    // word 2 
                    [4, 5, 6]
                ],
                // head 2
                [
                    [7, 8, 9],
                    [10, 11, 12]
                ],
            ], 
            // batch 2
            [
                [
                    [13, 14, 15],
                    [16, 17, 18]
                ],
                [
                    [19, 20, 21],
                    [22, 23, 24]
                ]
            ], 
            [
                [
                    [13, 14, 15], 
                    [16, 17, 18]
                ],
                [
                    [19, 20, 21],
                    [22, 23, 24]
                ]
            ]
        ]
"""
def self_attention(Q, K, V, mask = None):
    """Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k) + mask) @ V.

    The scaling factor is derived from the last dimension of ``Q`` rather than
    from a notebook-level global, so the function is self-contained and scales
    correctly when it is called with per-head slices whose width differs from
    the full embedding width.

    Args:
        Q: tensor of shape (..., n_queries, d_k).
        K: tensor of shape (..., n_keys, d_k).
        V: tensor of shape (..., n_keys, d_v).
        mask: optional additive mask that broadcasts against the
            (..., n_queries, n_keys) score tensor; use ``-inf`` for positions
            that must receive zero weight and ``0`` elsewhere. ``None``
            disables masking.

    Returns:
        A tuple ``(output, attention)`` where ``output`` has shape
        (..., n_queries, d_v) and ``attention`` has shape
        (..., n_queries, n_keys); every row of ``attention`` sums to 1.
    """
    d_k = Q.size(-1)
    # (..., n_queries, n_keys) similarity scores, scaled by sqrt(d_k).
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Adding -inf before the softmax drives the masked weights to exactly 0.
        scores = scores + mask
    attention = torch.nn.functional.softmax(scores, dim=-1)
    return torch.matmul(attention, V), attention

class MultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention with a fused Q/K/V projection.

    The input is projected once to ``3 * d_output_embedding`` features, split
    into ``n_head`` heads, passed through ``self_attention`` (defined elsewhere
    in this notebook), and the concatenated head outputs go through a final
    linear layer.

    Args:
        d_input_embedding: feature width of the input ``X``.
        d_output_embedding: feature width of the output; must be divisible by
            ``n_head``.
        n_head: number of attention heads.

    Raises:
        ValueError: if ``d_output_embedding`` is not divisible by ``n_head``.
    """

    def __init__(self, d_input_embedding, d_output_embedding, n_head):
        super(MultiHeadAttention, self).__init__()
        if d_output_embedding % n_head != 0:
            raise ValueError(
                f"d_output_embedding ({d_output_embedding}) must be divisible by n_head ({n_head})"
            )
        self.d_input_embedding = d_input_embedding
        self.d_output_embedding = d_output_embedding
        self.n_head = n_head
        self.head_dim = d_output_embedding // n_head
        self.qkv_projection_layer = torch.nn.Linear(d_input_embedding, 3 * d_output_embedding)
        self.linear = torch.nn.Linear(d_output_embedding, d_output_embedding)

    def forward(self, X, mask = None):
        """Apply multi-head self-attention to ``X``.

        Args:
            X: tensor of shape (batch_size, n_sequence_words, d_input_embedding).
            mask: optional additive mask forwarded unchanged to
                ``self_attention``; ``None`` disables masking.

        Returns:
            Tensor of shape (batch_size, n_sequence_words, d_output_embedding).
        """
        batch_size, n_sequence_words, _ = X.shape
        # Fused projection: (batch_size, n_sequence_words, 3 * d_output_embedding)
        qkv = self.qkv_projection_layer(X)
        # Use the module's own head count (not a notebook global) to split heads:
        # (batch_size, n_sequence_words, n_head, 3 * head_dim)
        qkv = qkv.reshape(batch_size, n_sequence_words, self.n_head, 3 * self.head_dim)
        # Attention is computed per head: (batch_size, n_head, n_sequence_words, 3 * head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        # Each of q, k, v: (batch_size, n_head, n_sequence_words, head_dim)
        q, k, v = qkv.chunk(3, dim=-1)
        # value: (batch_size, n_head, n_sequence_words, head_dim); the weights are not returned.
        value, _ = self_attention(q, k, v, mask)
        # Move the head axis back next to the feature axis BEFORE flattening;
        # reshaping straight from (batch, head, seq, dim) would interleave
        # different sequence positions into one output row.
        value = value.permute(0, 2, 1, 3)
        value = value.reshape(batch_size, n_sequence_words, self.d_output_embedding)  # concat n_head
        return self.linear(value)




In [11]:
# Random standard-normal input batch: 8 sequences of n_sequence_words vectors,
# each of width d_input_embedding.
X = torch.randn((8, n_sequence_words, d_input_embedding))

tensor([[0.0952, 0.0939, 0.1098, 0.0944, 0.1034, 0.0937, 0.1075, 0.0990, 0.0937,
         0.1094],
        [0.0943, 0.0838, 0.1043, 0.0947, 0.1089, 0.1020, 0.1171, 0.0917, 0.1072,
         0.0960],
        [0.1146, 0.0904, 0.0852, 0.0817, 0.0942, 0.1058, 0.1208, 0.0851, 0.1144,
         0.1077],
        [0.0980, 0.0927, 0.1052, 0.0921, 0.1004, 0.0962, 0.1074, 0.1013, 0.0961,
         0.1106],
        [0.1028, 0.0899, 0.0928, 0.0939, 0.0973, 0.1074, 0.1050, 0.1032, 0.1089,
         0.0988],
        [0.0966, 0.0831, 0.0990, 0.0943, 0.1049, 0.1062, 0.1133, 0.0972, 0.1104,
         0.0950],
        [0.1077, 0.0960, 0.0962, 0.0866, 0.1021, 0.0972, 0.1227, 0.0780, 0.1059,
         0.1075],
        [0.1024, 0.1042, 0.0943, 0.1055, 0.1008, 0.1043, 0.0977, 0.0942, 0.1065,
         0.0902],
        [0.0923, 0.1066, 0.0968, 0.1235, 0.0994, 0.1068, 0.0784, 0.1165, 0.1009,
         0.0787],
        [0.0985, 0.0964, 0.1027, 0.0980, 0.1038, 0.0988, 0.1067, 0.0939, 0.1011,
         0.1000]], grad_fn=<

In [13]:
# NOTE(review): `Encoder` is not defined in any cell visible here, so this line
# depends on kernel state from another cell — confirm it survives Restart & Run All.
# Indexes element [1] of the call's result, then [0][0] of that; judging by the
# printed 10x10 lower-triangular matrix, presumably the masked attention weights
# for the first batch item and first head — TODO confirm against `Encoder`.
Encoder(X, mask_matrix)[1][0][0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.5504, 0.4496, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.4226, 0.2996, 0.2779, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2804, 0.2362, 0.2588, 0.2246, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2482, 0.1895, 0.1830, 0.1836, 0.1956, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1798, 0.1350, 0.1596, 0.1551, 0.1774, 0.1930, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1512, 0.1222, 0.1326, 0.1240, 0.1535, 0.1449, 0.1715, 0.0000, 0.0000,
         0.0000],
        [0.1316, 0.1231, 0.1198, 0.1336, 0.1336, 0.1336, 0.1234, 0.1012, 0.0000,
         0.0000],
        [0.1071, 0.1147, 0.1065, 0.1327, 0.1091, 0.1170, 0.0871, 0.1153, 0.1104,
         0.0000],
        [0.1031, 0.0943, 0.1004, 0.1007, 0.1068, 0.1051, 0.1060, 0.0896, 0.1018,
         0.0921]], grad_fn=<