# Multi-Head Attention

In [44]:
from numpy import random
from tensorflow import matmul, math, cast, float32, reshape, shape, transpose
from tensorflow.keras.layers import Dense,Layer
from tensorflow.keras.backend import softmax

In [45]:
class DotProductAttention(Layer):
    """Scaled dot-product attention.

    Computes ``softmax(queries @ keys^T / sqrt(d_k) - 1e9 * mask) @ values``,
    with the softmax taken over the last (key) axis.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Attend ``queries`` over ``keys`` and return the weighted ``values``.

        Args:
            queries: Tensor ``(..., seq_q, depth)``. The scores are divided by a
                float32 factor, so queries/keys are expected to be float32.
            keys: Tensor ``(..., seq_k, depth)``; its last axis must match
                ``queries`` for the ``matmul``.
            values: Tensor ``(..., seq_k, depth_v)``.
            d_k: Scalar width used for the ``1 / sqrt(d_k)`` scaling; it is
                cast to float32.
            mask: Optional tensor broadcastable to the scores
                ``(..., seq_q, seq_k)``. Entries equal to 1 receive a ``-1e9``
                bias and therefore ~0 attention weight. Numeric and boolean
                masks are both accepted.

        Returns:
            Tensor ``(..., seq_q, depth_v)``.
        """
        # Raw attention logits, scaled by 1/sqrt(d_k).
        scores = matmul(queries, keys, transpose_b=True) / math.sqrt(cast(d_k, float32))

        if mask is not None:
            # Cast to the scores' dtype first: an int or bool mask would
            # otherwise yield a non-float bias that cannot be added to the
            # float scores.
            scores += -1e9 * cast(mask, scores.dtype)

        # Attention weights: softmax over the key axis (the last axis).
        weights = softmax(scores)

        # Weighted sum of the value vectors.
        return matmul(weights, values)

In [63]:
class MultiHeadAttention(Layer):
    """Multi-head attention built from ``DotProductAttention``.

    Queries, keys and values are each passed through a learned ``Dense``
    projection and split into ``h`` heads along the feature axis. Each head
    applies scaled dot-product attention. The heads are then concatenated and
    projected to ``d_model`` features by ``W_o``.

    Args:
        h: Number of attention heads.
        d_k: Output width of the query/key projections. It is split evenly
            across the heads, so each head works with ``d_k // h`` features.
        d_v: Output width of the value projection, likewise split so each
            head works with ``d_v // h`` features.
        d_model: Output width of the layer.
        **kwargs: Forwarded to ``Layer.__init__``.

    Raises:
        ValueError: If ``h`` is not positive, or ``d_k`` / ``d_v`` is not
            divisible by ``h``. Without this check, the per-head split in
            ``reshape_tensor`` would fail later with a reshape error.
    """

    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super().__init__(**kwargs)
        if h < 1 or d_k % h or d_v % h:
            raise ValueError(
                "h must be positive and divide both d_k and d_v; "
                f"got h={h}, d_k={d_k}, d_v={d_v}"
            )
        self.attention = DotProductAttention()  # Scaled dot-product attention
        self.heads = h  # Number of attention heads to use
        self.d_k = d_k  # Width of the projected queries and keys (all heads together)
        self.d_v = d_v  # Width of the projected values (all heads together)
        self.W_q = Dense(d_k)  # Learned projection for the queries
        self.W_k = Dense(d_k)  # Learned projection for the keys
        self.W_v = Dense(d_v)  # Learned projection for the values
        self.W_o = Dense(d_model)  # Learned projection for the concatenated heads

    def reshape_tensor(self, x, heads, flag):
        """Split the feature axis into heads, or merge the heads back.

        Args:
            x: When splitting, a tensor whose first two axes are kept
                (batch and sequence, as used by ``call``) and whose remaining
                features are divided into ``heads`` chunks. When merging, a
                rank-4 tensor ``(batch, heads, seq, features_per_head)``.
            heads: Number of heads to split into. Only used when ``flag`` is true.
            flag: True to split into heads, False to merge them.

        Returns:
            ``(batch, heads, seq, features // heads)`` when splitting, or
            ``(batch, seq, heads * features_per_head)`` when merging.
        """
        if flag:
            # (batch, seq, features) -> (batch, seq, heads, -1) -> (batch, heads, seq, -1)
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], heads, -1))
            x = transpose(x, perm=(0, 2, 1, 3))
        else:
            # (batch, heads, seq, f) -> (batch, seq, heads, f) -> (batch, seq, heads * f)
            x = transpose(x, perm=(0, 2, 1, 3))
            x = reshape(x, shape=(shape(x)[0], shape(x)[1], -1))
        return x

    def call(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        Args:
            queries, keys, values: Inputs to the ``W_q`` / ``W_k`` / ``W_v``
                projections. Assumed ``(batch, seq, features)``; TODO confirm
                for other ranks.
            mask: Optional mask forwarded to ``DotProductAttention`` (entries
                equal to 1 are suppressed). It must broadcast against the
                per-head scores ``(batch, h, seq_q, seq_k)``.

        Returns:
            Tensor ``(batch, seq_q, d_model)``.
        """
        q_heads = self.reshape_tensor(self.W_q(queries), self.heads, True)
        k_heads = self.reshape_tensor(self.W_k(keys), self.heads, True)
        v_heads = self.reshape_tensor(self.W_v(values), self.heads, True)

        # Each head dots query/key vectors of length d_k // h, so that is the
        # width the logits must be scaled by. Scaling by the full d_k would
        # shrink them by an extra factor of sqrt(h).
        head_d_k = self.d_k // self.heads
        o_heads = self.attention(q_heads, k_heads, v_heads, head_d_k, mask)

        output = self.reshape_tensor(o_heads, self.heads, False)
        return self.W_o(output)

### Test your code

In [64]:
# --- Test-run configuration ---------------------------------------------
batch_size: int = 64  # samples per training batch

# Attention geometry
h: int = 8  # number of self-attention heads
d_k: int = 64  # width of the linearly projected queries and keys
d_v: int = 64  # width of the linearly projected values
d_model: int = 512  # width of every model sub-layer output

In [65]:
input_seq_length = 5  # maximum length of the input sequence

# Uniform [0, 1) samples, drawn in the order queries, keys, values.
queries, keys, values = (
    random.random(size=(batch_size, input_seq_length, width))
    for width in (d_k, d_k, d_v)
)

In [66]:
# Instantiate the layer under test with the configuration values defined earlier.
multihead_attention = MultiHeadAttention(
    h=h,
    d_k=d_k,
    d_v=d_v,
    d_model=d_model,
)

In [67]:
# Run the layer on the random inputs (no mask) and print the resulting tensor;
# its last axis has d_model features because W_o is Dense(d_model).
print(multihead_attention(queries, keys, values))

tf.Tensor(
[[[-0.11264286 -0.01115484 -0.16644283 ... -0.23892288 -0.09416207
   -0.11787184]
  [-0.11201126 -0.01037188 -0.16619633 ... -0.2415398  -0.09289355
   -0.1176058 ]
  [-0.11209292 -0.01087035 -0.16752669 ... -0.23917238 -0.09657675
   -0.11846104]
  [-0.11138599 -0.01168628 -0.16623661 ... -0.23790407 -0.0932257
   -0.11753162]
  [-0.11168005 -0.00991426 -0.16730416 ... -0.23987913 -0.09263724
   -0.11674593]]

 [[-0.24275436 -0.14067279 -0.25067586 ... -0.13608369 -0.13150424
   -0.0601126 ]
  [-0.24326378 -0.13740389 -0.25057405 ... -0.13658503 -0.1285262
   -0.06044139]
  [-0.2414822  -0.13956553 -0.25132072 ... -0.13959934 -0.12791435
   -0.06222455]
  [-0.24204466 -0.14133568 -0.24985853 ... -0.1375808  -0.13070369
   -0.06164463]
  [-0.244182   -0.14027199 -0.2513249  ... -0.13857861 -0.12941486
   -0.0616059 ]]

 [[-0.23782581 -0.06970622 -0.34440106 ... -0.09732933 -0.2258077
   -0.16578905]
  [-0.23736416 -0.07232311 -0.3429694  ... -0.09635789 -0.22988911
   -0.16