In [1]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
import math

In [2]:
# Fix the global TensorFlow seed so the random input below (and therefore
# every downstream output) is the same after Restart Kernel -> Run All.
SEED = 42  # arbitrary value; only its being fixed matters
tf.random.set_seed(SEED)

batch_size = 1    # number of sequences in the demo batch
max_seqlen = 200  # number of positions per sequence
d_model = 512     # feature width of every position
n_heads = 8       # number of attention heads d_model is split across

# Random stand-in for an embedded input batch: (batch_size, max_seqlen, d_model).
x = tf.random.normal((batch_size, max_seqlen, d_model))

In [7]:
def scaled_dot_product(q, k, v, mask = None):
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

  Args:
    q: query tensor of shape (..., seq_q, d_k).
    k: key tensor of shape (..., seq_k, d_k).
    v: value tensor of shape (..., seq_k, d_v).
    mask: optional additive mask broadcastable to (..., seq_q, seq_k). Use 0
      where attention is allowed and a large negative number (e.g. -1e9)
      where it must be blocked.

  Returns:
    A tuple (values, attention_weights): the attended values of shape
    (..., seq_q, d_v) and the softmax-normalised weights of shape
    (..., seq_q, seq_k).
  """
  # transpose_b swaps only the last two axes of k, so this works for any
  # number of leading (batch / head) dimensions.
  logits = tf.matmul(q, k, transpose_b=True) / math.sqrt(q.shape[-1])
  if mask is not None:
    # The mask has to hit the logits *before* the softmax: blocked positions
    # then receive ~0 probability and every row still sums to 1. Adding it
    # after the softmax would inject huge negative "weights" instead.
    logits = logits + mask
  attention_weights = tf.nn.softmax(logits, axis = -1)
  values = tf.matmul(attention_weights, v)
  return values, attention_weights

class MultiHeadAttention(tf.keras.Model):
  """Multi-head self-attention over a (batch, seq, d_model) input.

  A single dense layer produces queries, keys and values for all heads at
  once; attention runs independently per head and the concatenated head
  outputs go through a final dense projection back to d_model features.
  """

  def __init__(self, d_model, n_heads):
    """Create the projection layers.

    Args:
      d_model: feature width of the input and of the output.
      n_heads: number of attention heads; must divide d_model evenly.

    Raises:
      ValueError: if d_model is not divisible by n_heads.
    """
    super().__init__()
    if d_model % n_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.head_dim = d_model // n_heads
    # One fused projection yields q, k and v for every head.
    self.qkv_dense = Dense(d_model*3)
    self.out_dense = Dense(d_model)

  def call(self, inputs, mask=None):
    """Apply self-attention to `inputs`.

    Args:
      inputs: tensor of shape (batch, seq, d_model).
      mask: optional mask forwarded unchanged to scaled_dot_product; it must
        broadcast against the (batch, n_heads, seq, seq) attention scores.

    Returns:
      Tensor of shape (batch, seq, d_model).
    """
    # Read sizes dynamically so an unknown (None) static batch or sequence
    # dimension does not break the reshapes below.
    input_shape = tf.shape(inputs)
    batch_size, seq_len = input_shape[0], input_shape[1]

    qkv = self.qkv_dense(inputs)
    # (batch, seq, 3*d_model) -> (batch, seq, heads, 3*head_dim)
    qkv = tf.reshape(qkv, (batch_size, seq_len, self.n_heads, 3*self.head_dim))
    # -> (batch, heads, seq, 3*head_dim) so attention is computed per head.
    qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
    q, k, v = tf.split(qkv, 3, axis = -1)

    values, _ = scaled_dot_product(q, k, v, mask)

    # Bring the sequence axis back in front of the head axis *before* merging
    # heads; reshaping (batch, heads, seq, head_dim) directly would mix
    # features from different positions and heads.
    values = tf.transpose(values, perm=[0, 2, 1, 3])
    values = tf.reshape(values, (batch_size, seq_len, self.n_heads*self.head_dim))
    return self.out_dense(values)

In [8]:
# Build the attention layer from the notebook-level d_model / n_heads settings.
mha = MultiHeadAttention(d_model, n_heads)
# Run it on the random input x without a mask; as the cell's last expression
# the result is displayed below (a float32 tensor of shape (1, 200, 512)).
mha(x)

<tf.Tensor: shape=(1, 200, 512), dtype=float32, numpy=
array([[[ 0.01795023, -0.12091614,  0.09743282, ..., -0.03686188,
         -0.09144451, -0.01387916],
        [-0.01523148, -0.07521614,  0.09003463, ...,  0.00406478,
         -0.07066263, -0.04080007],
        [ 0.0294656 , -0.07237387,  0.04118561, ..., -0.00842303,
         -0.06011126, -0.02426854],
        ...,
        [ 0.06071714,  0.0051684 , -0.04929529, ..., -0.01382769,
          0.03571779, -0.05440907],
        [ 0.08154583, -0.00788801, -0.05433886, ..., -0.01039406,
          0.09396115,  0.04686199],
        [ 0.10124335,  0.0557582 , -0.01287371, ..., -0.04106165,
          0.04083134,  0.03867629]]], dtype=float32)>