In [69]:
import tensorflow as tf
import numpy as np

# Symbolic input: sequences of 4 positions with 3 features each; the batch
# dimension is left unspecified.
x = tf.keras.Input(shape=(4, 3))

# Single attention head with 2-dimensional query/key projections and no bias
# terms, which keeps the layer small enough to check by hand.
layer = tf.keras.layers.MultiHeadAttention(
    num_heads=1,
    key_dim=2,
    use_bias=False,
)

# Self-attention: the same tensor is passed as both query and value.
# Calling the layer once also creates its weights.
output_tensor = layer(x, x)
print(output_tensor.shape)

(None, 4, 3)


In [70]:
# Fetch the layer's parameters as a list of arrays and report how many
# there are.
weights = layer.get_weights()
print(len(weights))

4


In [71]:
# Show the shape of every kernel, one per line (the layer holds exactly the
# four kernels counted in the previous cell).
print(*(kernel.shape for kernel in weights), sep="\n")

(3, 1, 2)
(3, 1, 2)
(3, 1, 2)
(1, 2, 3)


In [72]:
# Hand-picked kernels, reshaped to match the shapes reported by
# layer.get_weights(): (3, 1, 2) for the query/key/value projections and
# (1, 2, 3) for the output projection.
q = np.array([[ 0.4,  0.3],
              [-0.1, -0.1],
              [ 0.2, -0.1]]).reshape((3, 1, 2))
k = np.array([[ 0.1,  0.2],
              [-0.3, -0.4],
              [-0.1,  0.2]]).reshape((3, 1, 2))
v = np.array([[-0.2,  0.1],
              [-0.4,  0.2],
              [ 0.4, -0.6]]).reshape((3, 1, 2))
o = np.array([[ 0.1, -0.1,  0.6],
              [ 0.9,  0.3,  0.1]]).reshape((1, 2, 3))

In [73]:
# Replace the layer's current kernels with the hand-written arrays, in the
# same order that get_weights() lists them (query, key, value, output).
layer.set_weights([q, k, v, o])

In [74]:
# Read the kernels back to confirm that set_weights() stored the intended
# values; each array is printed on its own line(s).
weights = layer.get_weights()
print(*weights, sep="\n")

[[[ 0.4  0.3]]

 [[-0.1 -0.1]]

 [[ 0.2 -0.1]]]
[[[ 0.1  0.2]]

 [[-0.3 -0.4]]

 [[-0.1  0.2]]]
[[[-0.2  0.1]]

 [[-0.4  0.2]]

 [[ 0.4 -0.6]]]
[[[ 0.1 -0.1  0.6]
  [ 0.9  0.3  0.1]]]


In [75]:
# One batch holding a sequence of 4 positions with 3 features each,
# written directly in its final (1, 4, 3) layout.
data = np.array([[[1., 3., 2.],
                  [6., 2., 1.],
                  [5., 8., 4.],
                  [7., 3., 4.]]])
print(data.shape, data, sep="\n")

(1, 4, 3)
[[[1. 3. 2.]
  [6. 2. 1.]
  [5. 8. 4.]
  [7. 3. 4.]]]


In [76]:
# Run self-attention on the concrete batch and also request the attention
# scores. NOTE: `weights` is rebound here from the kernel list to the
# attention-score tensor.
output_tensor, weights = layer(
    data,
    data,
    return_attention_scores=True,
)
print(output_tensor.shape, weights.shape, sep="\n")

(1, 4, 3)
(1, 1, 4, 4)


In [77]:
# Layer output first, then the attention scores, for comparison with the
# manual computation.
print(output_tensor, weights, sep="\n")

tf.Tensor(
[[[-0.43691295  0.03797075 -0.85948306]
  [-0.3176294   0.07249315 -0.8230913 ]
  [-0.28671634  0.07962245 -0.80563337]
  [-0.21573429  0.11351576 -0.8429406 ]]], shape=(1, 4, 3), dtype=float32)
tf.Tensor(
[[[[2.6413780e-01 3.0642202e-01 1.8679065e-01 2.4264953e-01]
   [3.5392467e-02 5.8209985e-01 1.6682872e-03 3.8083941e-01]
   [1.2176767e-01 5.6085670e-01 1.6345147e-02 3.0103046e-01]
   [2.4871059e-02 6.6630757e-01 5.4240436e-04 3.0827892e-01]]]], shape=(1, 1, 4, 4), dtype=float32)


## Verify: reproduce the Keras result by hand with NumPy

In [1]:
import tensorflow as tf
import numpy as np

# The projection kernels are written in the (3, 1, 2) layout and then
# collapsed to plain (3, 2) matrices, since the middle axis has length 1.
W_Q, W_K, W_V = (
    np.array(kernel).reshape((3, 2))
    for kernel in (
        [[[ 0.4,  0.3]], [[-0.1, -0.1]], [[ 0.2, -0.1]]],   # query
        [[[ 0.1,  0.2]], [[-0.3, -0.4]], [[-0.1,  0.2]]],   # key
        [[[-0.2,  0.1]], [[-0.4,  0.2]], [[ 0.4, -0.6]]],   # value
    )
)

print(W_Q, W_K, W_V, sep="\n")

[[ 0.4  0.3]
 [-0.1 -0.1]
 [ 0.2 -0.1]]
[[ 0.1  0.2]
 [-0.3 -0.4]
 [-0.1  0.2]]
[[-0.2  0.1]
 [-0.4  0.2]
 [ 0.4 -0.6]]


In [2]:
# Input sequence as a plain (4, 3) matrix: 4 positions, 3 features each.
data = np.array([[1., 3., 2.],
                 [6., 2., 1.],
                 [5., 8., 4.],
                 [7., 3., 4.]])
print(data.shape, data, sep="\n")

(4, 3)
[[1. 3. 2.]
 [6. 2. 1.]
 [5. 8. 4.]
 [7. 3. 4.]]


In [3]:
# Project the input with each kernel to obtain the queries, keys and values.
Q, K, V = (data.dot(kernel) for kernel in (W_Q, W_K, W_V))

print(Q, K, V, sep="\n")

[[ 0.5 -0.2]
 [ 2.4  1.5]
 [ 2.   0.3]
 [ 3.3  1.4]]
[[-1.  -0.6]
 [-0.1  0.6]
 [-2.3 -1.4]
 [-0.6  1. ]]
[[-0.6 -0.5]
 [-1.6  0.4]
 [-2.6 -0.3]
 [-1.  -1.1]]


In [6]:
def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``.

    Parameters
    ----------
    x : array_like
        Input scores of any rank.
    axis : int, optional
        Axis along which to normalise. Defaults to the last axis, which for
        a 2-D input is the same row-wise normalisation as before.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    x = np.asarray(x)

    # Subtracting the per-slice maximum does not change the result but keeps
    # np.exp from overflowing on large scores. keepdims=True lets the
    # reductions broadcast back against x for any rank.
    e_x = np.exp(x - x.max(axis=axis, keepdims=True))

    return e_x / e_x.sum(axis=axis, keepdims=True)

# Scaled dot-product attention weights: softmax(Q K^T / sqrt(d_k)) per query
# row. d_k is read from K's last dimension instead of being hard-coded as 2,
# so the scaling follows the projection size.
alpha = softmax(np.dot(Q, K.T) / np.sqrt(K.shape[-1]))
print(alpha)

[[2.64137789e-01 3.06422025e-01 1.86790669e-01 2.42649516e-01]
 [3.53924695e-02 5.82099808e-01 1.66828771e-03 3.80839435e-01]
 [1.21767661e-01 5.60856713e-01 1.63451555e-02 3.01030471e-01]
 [2.48710601e-02 6.66307594e-01 5.42404741e-04 3.08278941e-01]]


In [23]:
# Raw scaled scores before the softmax; the scale is sqrt(d_k), with d_k taken
# from K's last dimension rather than hard-coded.
np.dot(Q, K.T) / np.sqrt(K.shape[-1])

array([[-0.26870058, -0.12020815, -0.6151829 , -0.35355339],
       [-2.33345238,  0.46669048, -5.38815367,  0.04242641],
       [-1.54149278, -0.01414214, -3.54967604, -0.6363961 ],
       [-2.92742207,  0.36062446, -6.75286976, -0.41012193]])

In [20]:
# Each row is the attention-weighted combination of the value rows.
context_vector = alpha.dot(V)
print(context_vector)

[[-1.37706317 -0.33245175]
 [-1.33777216 -0.20428018]
 [-1.31395921 -0.17257821]
 [-1.39070398 -0.08518205]]


In [21]:
# Output projection kernel, written in its (1, 2, 3) layout and collapsed to
# a plain (2, 3) matrix, since the leading axis has length 1.
W_O = np.array([[[ 0.1, -0.1,  0.6],
                 [ 0.9,  0.3,  0.1]]]).reshape((2, 3))
print(W_O)

[[ 0.1 -0.1  0.6]
 [ 0.9  0.3  0.1]]


In [22]:
# Apply the output projection to the context vectors to get the final result.
output = context_vector.dot(W_O)
print(output)

[[-0.43691289  0.03797079 -0.85948308]
 [-0.31762937  0.07249316 -0.82309131]
 [-0.28671631  0.07962246 -0.80563335]
 [-0.21573424  0.11351578 -0.84294059]]


In [None]:
# keras
# Reference values copied from the MultiHeadAttention output printed earlier
# in this notebook (float32, with a leading batch axis of size 1).
keras_output = np.array([[[-0.43691295,  0.03797075, -0.85948306],
                          [-0.3176294 ,  0.07249315, -0.8230913 ],
                          [-0.28671634,  0.07962245, -0.80563337],
                          [-0.21573429,  0.11351576, -0.8429406 ]]])

# The manual NumPy result has no batch axis, so compare against batch 0.
# The tolerances absorb the float32 rounding in the Keras values.
np.testing.assert_allclose(output, keras_output[0], rtol=1e-5, atol=1e-6)