In [6]:
import numpy as np

from keras.layers import Dense, Dropout, Permute, Reshape
from keras.models import Model

In [None]:
class MultiheadAttention(Model):
    """Multi-head attention block (work in progress).

    Projects ``query`` and ``memory`` through bias-free dense layers, splits
    the projections into ``head_num`` heads and scales the query for scaled
    dot-product attention. The attention-weight / output computation has not
    been written yet (see the TODO in ``call``).

    Args:
        hidden_dim: Width of the q/k/v/output projections. Has to be a
            multiple of ``head_num`` so it can be split evenly across heads.
        head_num: Number of attention heads.
        dropout_rate: Rate passed to the attention ``Dropout`` layer.
        **kwargs: Forwarded unchanged to ``keras.models.Model.__init__``.

    Raises:
        ValueError: If ``head_num`` is not positive or ``hidden_dim`` is not
            a multiple of ``head_num``.
    """

    def __init__(self, hidden_dim=512, head_num=8, dropout_rate=0.1, **kwargs):
        # Fail early with a clear message instead of a reshape error later on.
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                "hidden_dim ({}) has to be a positive multiple of head_num ({})".format(
                    hidden_dim, head_num
                )
            )
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.dropout_rate = dropout_rate

        self.q_dense_layer = Dense(hidden_dim, use_bias=False, name="q_dense_layer")
        self.k_dense_layer = Dense(hidden_dim, use_bias=False, name="k_dense_layer")
        self.v_dense_layer = Dense(hidden_dim, use_bias=False, name="v_dense_layer")
        self.output_dense_layer = Dense(hidden_dim, use_bias=False, name="output_dense_layer")
        self.attention_dropout_layer = Dropout(dropout_rate, name="attention_dropout_layer")

        # Weight-free helper layers used by split_head. Tensors do not offer
        # NumPy-style .reshape()/.transpose() methods, so the same operations
        # are expressed as layers. Both layers exclude the batch axis from
        # their arguments; -1 lets Reshape infer the sequence length.
        self.split_head_reshape_layer = Reshape(
            (-1, head_num, hidden_dim // head_num), name="split_head_reshape_layer"
        )
        # Permute indexes axes from 1 (batch axis excluded): swap seq and head.
        self.split_head_permute_layer = Permute((2, 1, 3), name="split_head_permute_layer")

    def split_head(self, x):
        """Split the last axis of ``x`` into heads.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, hidden_dim)``.

        Returns:
            Tensor of shape
            ``(batch_size, head_num, seq_length, hidden_dim // head_num)``.
        """
        # (batch, seq, hidden) -> (batch, seq, head, depth)
        x = self.split_head_reshape_layer(x)
        # (batch, seq, head, depth) -> (batch, head, seq, depth)
        return self.split_head_permute_layer(x)

    def call(self, query, memory, attention_mask, for_train):
        """Project the inputs, split them into heads and scale the query.

        Args:
            query: Input fed to the query projection.
            memory: Input fed to both the key and the value projections.
            attention_mask: Not used yet.
            for_train: Not used yet.

        Returns:
            Currently ``None`` -- the method is unfinished.
        """
        q = self.split_head(self.q_dense_layer(query))
        k = self.split_head(self.k_dense_layer(memory))
        v = self.split_head(self.v_dense_layer(memory))

        # For scaled dot-product attention: scale q by 1 / sqrt(depth).
        depth_inside_each_head = self.hidden_dim // self.head_num
        q = q * depth_inside_each_head ** -0.5

        # TODO: compute attention weights from q and k (applying
        # attention_mask and attention_dropout_layer), combine them with v,
        # merge the heads and apply output_dense_layer. Until then this
        # method implicitly returns None.
        
        

In [7]:

# Scratch check: a 2x2 nested list turns into an array of shape (2, 2).
test_array = np.array(
    [
        [1, 2],
        [3, 4],
    ]
)
test_array.shape

(2, 2)

In [16]:
# Scratch data for trying out reshape/transpose: the integers 0..39.
test = np.arange(40)
print(test.shape)
# Regroup the 40 values into 5 blocks of 4 rows x 2 columns.
test = np.reshape(test, (5, 4, 2))
test

(40,)


array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7]],

       [[ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]],

       [[16, 17],
        [18, 19],
        [20, 21],
        [22, 23]],

       [[24, 25],
        [26, 27],
        [28, 29],
        [30, 31]],

       [[32, 33],
        [34, 35],
        [36, 37],
        [38, 39]]])

In [22]:
# Reshape alone only regroups values in their existing order; it does not
# move axes (relies on `test` from the arange/reshape cell).
np.reshape(test, (5, 2, 4, 1))

array([[[[ 0],
         [ 1],
         [ 2],
         [ 3]],

        [[ 4],
         [ 5],
         [ 6],
         [ 7]]],


       [[[ 8],
         [ 9],
         [10],
         [11]],

        [[12],
         [13],
         [14],
         [15]]],


       [[[16],
         [17],
         [18],
         [19]],

        [[20],
         [21],
         [22],
         [23]]],


       [[[24],
         [25],
         [26],
         [27]],

        [[28],
         [29],
         [30],
         [31]]],


       [[[32],
         [33],
         [34],
         [35]],

        [[36],
         [37],
         [38],
         [39]]]])

In [20]:
# Transpose reorders axes: the last axis of `test` becomes the first
# (relies on `test` from the arange/reshape cell).
np.transpose(test, (2, 0, 1))

array([[[ 0,  2,  4,  6],
        [ 8, 10, 12, 14],
        [16, 18, 20, 22],
        [24, 26, 28, 30],
        [32, 34, 36, 38]],

       [[ 1,  3,  5,  7],
        [ 9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39]]])