In [2]:
import tensorflow as tf
from keras import layers
from keras.layers import Dense, Attention, BatchNormalization, Input, Conv2D, Permute


In [3]:
class AttentionBlock(layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.b_norm = layers.BatchNormalization()
        self.query = layers.Dense(units, kernel_initializer=tf.keras.initializers.RandomUniform())
        self.key = layers.Dense(units, kernel_initializer=tf.keras.initializers.RandomUniform())
        self.value = layers.Dense(units, kernel_initializer=tf.keras.initializers.RandomUniform())
        self.output_layer = layers.Dense(units, kernel_initializer=tf.keras.initializers.RandomUniform())

    def call(self, inputs):
        inputs = self.b_norm(inputs)
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)

        batch, head, length, dim = k.shape
        product = tf.einsum("b h i d, b h j d -> b h i j", q, k)

        scale_product = product*dim**(-0.5)
        attention = tf.keras.activations.softmax(scale_product)
        output = tf.einsum('b h i j, b h j d -> b h i d', attention, v)
        print(output.shape)
        output = self.output_layer(output)
        print(output.shape)
        print(inputs.shape)
        return inputs + output


In [4]:
filters = 16
input_layer = Input(shape=(128,128,3))
conv_layer = Conv2D(strides=1, filters = filters, kernel_size=(3,3), padding='same')(input_layer)
attention_layer = AttentionBlock(units=filters)(conv_layer)

(None, 128, 128, 16)
(None, 128, 128, 16)
(None, 128, 128, 16)
