In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
sequence_length = 4 # maximum sequence length
batch_size = 1
input_dim = 512 # vector dimension
d_model = 512 # output dimension
x = torch.randn( (batch_size, sequence_length, input_dim) )

In [4]:
x.size()

torch.Size([1, 4, 512])

In [5]:
qkv_layer = nn.Linear(input_dim, 3*d_model) # all concatenated

In [9]:
qkv = qkv_layer(x)

In [11]:
qkv.shape

torch.Size([1, 4, 1536])

In [12]:
num_heads = 8
head_dim = d_model // num_heads # 64
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3*head_dim)

In [14]:
qkv.shape

torch.Size([1, 4, 8, 192])

In [16]:
qkv = qkv.permute(0, 2, 1, 3) # switch
qkv.shape

torch.Size([1, 8, 4, 192])

In [17]:
q, k, v = qkv.chunk(3, dim=-1) # chuncking very last dimension
q.shape, k.shape, v.shape

(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))

In [19]:
import math
d_k = q.size()[-1] # 64
# transpose last 2 dimension (sequence_length, head_dim size)
# dot product is not worked because it's not matrix, it's tensor.
scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

scaled.shape
scaled.size()

torch.Size([1, 8, 4, 4])

In [24]:
# masking (in decoder)

mask = torch.full(scaled.size(), float('-inf'))
mask = torch.triu(mask, diagonal=1)
mask[0][1]

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [26]:
(scaled+mask)[0][0]

tensor([[ 0.0325,    -inf,    -inf,    -inf],
        [-0.1608, -0.5618,    -inf,    -inf],
        [ 0.3121,  0.4271, -0.5060,    -inf],
        [ 0.0263,  0.0561, -0.2624, -0.1564]], grad_fn=<SelectBackward0>)

In [28]:
attention = F.softmax(scaled,dim=-1) # row by row

In [30]:
attention[0][0]

tensor([[0.2331, 0.1941, 0.2471, 0.3256],
        [0.2424, 0.1623, 0.2478, 0.3475],
        [0.3350, 0.3758, 0.1478, 0.1414],
        [0.2768, 0.2852, 0.2074, 0.2306]], grad_fn=<SelectBackward0>)

In [31]:
values = torch.matmul(attention, v)