In [2]:
import tensorflow as tf
from tensorflow.keras import layers

In [3]:
def scaled_dot_product_attention(queries, keys, values, mask):
    """Compute softmax(Q K^T / sqrt(d_k) + mask * penalty) V.

    Args:
        queries: tensor of shape (..., seq_q, d_k).
        keys: tensor of shape (..., seq_k, d_k), same dtype as ``queries``.
        values: tensor of shape (..., seq_k, d_v).
        mask: ``None`` or a tensor broadcastable to (..., seq_q, seq_k).
            Entries equal to 1 are suppressed and entries equal to 0 are kept.
            Bool/int/float masks are accepted; the mask is cast to the logits'
            dtype.

    Returns:
        Tensor of shape (..., seq_q, d_v).
    """
    product = tf.matmul(queries, keys, transpose_b=True)
    dtype = product.dtype
    # Use the logits' dtype rather than a hard-coded float32, so float16 and
    # float64 inputs don't fail with a dtype mismatch in the division.
    keys_dim = tf.cast(tf.shape(keys)[-1], dtype)
    scaled_product = product / tf.math.sqrt(keys_dim)
    if mask is not None:
        # -1e9 overflows float16 to -inf, and 0 * -inf = NaN on unmasked
        # positions. Clamp the penalty to the dtype's finite range.
        # This is still exactly -1e9 for float32/float64.
        penalty = tf.constant(max(-1e9, dtype.min), dtype=dtype)
        scaled_product += tf.cast(mask, dtype) * penalty
    attention_weights = tf.nn.softmax(scaled_product, axis=-1)
    return tf.matmul(attention_weights, values)

In [4]:
class MultiHeadAttention1(layers.Layer):
    """Multi-head attention whose projection layers are created in ``build``.

    Args:
        d_model: total projection width, split evenly across heads. If
            ``None``, the last dimension of the shape passed to ``build``
            (the first ``queries`` seen) is used instead.
        num_heads: number of attention heads; must divide ``d_model``.
        length: stored as ``self.length``; not otherwise used by this layer.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, d_model, num_heads, length, **kwargs):
        super().__init__(**kwargs)
        self.n_heads = num_heads
        self.d_model = d_model
        self.length = length

    def build(self, input_shape):
        """Create the Q/K/V/output Dense projections of width ``d_model``.

        Raises:
            ValueError: if ``d_model`` is not divisible by the number of heads.
        """
        # Previously d_model was always overwritten with the input width,
        # silently discarding the constructor argument. Infer it only when
        # the caller did not supply one.
        if self.d_model is None:
            self.d_model = input_shape[-1]
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"num_heads ({self.n_heads})"
            )
        self.d_head = self.d_model // self.n_heads
        self.query_lin = layers.Dense(units=self.d_model)
        self.key_lin = layers.Dense(units=self.d_model)
        self.value_lin = layers.Dense(units=self.d_model)
        self.final_lin = layers.Dense(units=self.d_model)
        super().build(input_shape)

    def split_proj(self, inputs, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, d_head)."""
        shape = (batch_size, -1, self.n_heads, self.d_head)
        splitted_inputs = tf.reshape(inputs, shape=shape)
        return tf.transpose(splitted_inputs, perm=[0, 2, 1, 3])

    def call(self, queries, keys, values, mask=None):
        """Attend from ``queries`` to ``keys``/``values`` across all heads.

        Args:
            queries: rank-3 tensor (batch, seq_q, features).
            keys: rank-3 tensor (batch, seq_k, features).
            values: rank-3 tensor (batch, seq_k, features).
            mask: optional tensor broadcastable to
                (batch, n_heads, seq_q, seq_k). Ones mark positions to
                suppress. Defaults to ``None`` (no masking).

        Returns:
            Tensor of shape (batch, seq_q, d_model).
        """
        batch_size = tf.shape(queries)[0]
        queries = self.split_proj(self.query_lin(queries), batch_size)
        keys = self.split_proj(self.key_lin(keys), batch_size)
        values = self.split_proj(self.value_lin(values), batch_size)
        attention = scaled_dot_product_attention(queries, keys, values, mask)
        # (batch, heads, seq_q, d_head) -> (batch, seq_q, heads, d_head),
        # then merge the heads back into a single d_model axis.
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention, shape=(batch_size, -1, self.d_model))
        return self.final_lin(concat_attention)

In [None]:
class MultiHeadAttention(layers.Layer):
    """Work-in-progress multi-head attention with relative-position helpers.

    Holds the Q/K/V/output projections and helpers for relative-position
    attention. NOTE(review): no ``call`` method is defined yet, so the layer
    cannot be applied to inputs as-is.

    Args:
        d_model: projection width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads
        # NOTE(review): ReLU on the Q/K/V/output projections is unusual for
        # attention (Vaswani et al. use linear maps). It is kept here to
        # preserve behavior; confirm it is intended.
        self.wq = layers.Dense(d_model, activation='relu')
        self.wk = layers.Dense(d_model, activation='relu')
        self.wv = layers.Dense(d_model, activation='relu')
        self.dense = layers.Dense(d_model, activation='relu')
        # Relative-position embedding tables, created lazily and keyed by name.
        self._relative_embedding_tables = {}

    def split_heads(self, x, batch_size):
        """Reshape x to (batch, num_heads, seq, depth).

        Assumes x is (batch, seq, d_model).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def _generate_relative_positions_matrix(self, length, max_relative_position, cache=False):
        """Return clipped relative distances shifted into [0, 2 * max_relative_position].

        Without ``cache``: a (length, length) int tensor whose [i, j] entry is
        clip(j - i, -max, max) + max.
        With ``cache``: a (1, length) tensor built from offsets -length+1 .. 0.
        The values are suitable as row indices into a table with
        2 * max_relative_position + 1 rows.
        """
        if not cache:
            range_vec = tf.range(length)
            # Broadcasting gives distance_mat[i, j] = j - i. This is the same
            # result as the tile/reshape/transpose formulation.
            distance_mat = range_vec[tf.newaxis, :] - range_vec[:, tf.newaxis]
        else:
            distance_mat = tf.expand_dims(tf.range(-length + 1, 1, 1), 0)
        distance_mat_clipped = tf.clip_by_value(
            distance_mat, -max_relative_position, max_relative_position)
        # Previously this value was computed but never returned (the method returned None).
        return distance_mat_clipped + max_relative_position

    # Backward-compatible alias for the original (singular) method name.
    _generate_relative_position_matrix = _generate_relative_positions_matrix

    def _generate_relative_positions_embeddings(self, length, depth, max_relative_position, name, cache=False):
        """Gather learned embeddings for each clipped relative position.

        The table has shape (2 * max_relative_position + 1, depth). It is
        created on first use with ``add_weight``, so it is tracked as a weight
        of this layer. It is reused on later calls with the same ``name``.
        NOTE(review): later calls with the same ``name`` but a different
        ``depth``/``max_relative_position`` reuse the first table.

        Returns:
            (length, length, depth) tensor, or (1, length, depth) when ``cache``.
        """
        relative_positions_matrix = self._generate_relative_positions_matrix(
            length, max_relative_position, cache=cache)
        # tf.compat.v1.get_variable inside an eager Keras layer creates a new,
        # untracked variable on every call. add_weight plus the cache keeps a
        # single tracked table per name.
        embeddings_table = self._relative_embedding_tables.get(name)
        if embeddings_table is None:
            vocab_size = max_relative_position * 2 + 1
            embeddings_table = self.add_weight(name=name, shape=(vocab_size, depth))
            self._relative_embedding_tables[name] = embeddings_table
        return tf.gather(embeddings_table, relative_positions_matrix)

    def _relative_attention_inner(self, x, y, z, transpose, mask, length):
        """Return x @ y(^T) plus a per-query-position product of x with z.

        Args:
            x: rank-4 tensor (batch, heads, length, d).
            y: tensor matmul-compatible with x (e.g. keys or attention logits).
            z: rank-3 tensor whose leading axis is ``length`` (relative-position
                embeddings for each query position).
            transpose: whether to transpose the last two axes of y and z.
            mask: unused; kept for signature compatibility.
            length: size of x's third axis.

        Returns:
            Tensor of shape (batch, heads, length, ...).
        """
        batch_size = tf.shape(x)[0]
        # Read the head count from x itself instead of an undefined global.
        heads = tf.shape(x)[1]
        xy_matmul = tf.matmul(x, y, transpose_b=transpose)
        # Move the query-position axis to the front so each position can be
        # multiplied by its own slice of z in a single batched matmul.
        x_t = tf.transpose(x, [2, 0, 1, 3])
        x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
        x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
        x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
        x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
        return tf.math.add(xy_matmul, x_tz_matmul_r_t)

In [8]:
# Pairwise relative distances for a length-5 sequence: entry [i, j] is j - i.
tf.range(5)[tf.newaxis, :] - tf.range(5)[:, tf.newaxis]

<tf.Tensor: shape=(5, 5), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [-1,  0,  1,  2,  3],
       [-2, -1,  0,  1,  2],
       [-3, -2, -1,  0,  1],
       [-4, -3, -2, -1,  0]], dtype=int32)>

In [9]:
# Single-row variant (the cache=True branch): offsets -4 .. 0 as a (1, 5) tensor.
tf.reshape(tf.range(1 - 5, 1), [1, -1])

<tf.Tensor: shape=(1, 5), dtype=int32, numpy=array([[-4, -3, -2, -1,  0]], dtype=int32)>

In [12]:
# The same j - i distance matrix, clipped to the window [-2, 2].
tf.clip_by_value(tf.range(5)[tf.newaxis, :] - tf.range(5)[:, tf.newaxis], clip_value_min=-2, clip_value_max=2)

<tf.Tensor: shape=(5, 5), dtype=int32, numpy=
array([[ 0,  1,  2,  2,  2],
       [-1,  0,  1,  2,  2],
       [-2, -1,  0,  1,  2],
       [-2, -2, -1,  0,  1],
       [-2, -2, -2, -1,  0]], dtype=int32)>

In [13]:
# Clipped distances shifted by +2 so they fall in 0..4, i.e. valid row indices
# into a (2 * 2 + 1)-row relative-position embedding table.
2 + tf.clip_by_value(tf.range(5)[tf.newaxis, :] - tf.range(5)[:, tf.newaxis], clip_value_min=-2, clip_value_max=2)

<tf.Tensor: shape=(5, 5), dtype=int32, numpy=
array([[2, 3, 4, 4, 4],
       [1, 2, 3, 4, 4],
       [0, 1, 2, 3, 4],
       [0, 0, 1, 2, 3],
       [0, 0, 0, 1, 2]], dtype=int32)>