In [2]:
import tensorflow as tf
import tensorflow.keras.layers as kl

In [8]:
class SelfAttention(tf.keras.layers.Layer):
    """Additive attention gate that re-weights a feature map ``x`` using a gate ``g``.

    Despite the name, this is not classic self-attention: the gating
    coefficients are computed from both ``x`` and the gating signal ``g``.

    Pipeline of ``call``:
        1. ``W_gate(g)`` and ``W_x(x)``: 1x1 conv + batch-norm, each producing
           ``channels`` feature maps; the two results are summed and ReLU'd.
        2. ``phi``: 1x1 conv + batch-norm back to ``x``'s channel count,
           followed by a sigmoid.
        3. ``kl.Attention`` with ``x`` as query and the sigmoid map as value.
        4. The attention output is multiplied element-wise with ``x``.

    Args:
        channels: Number of intermediate feature maps produced by ``W_gate``
            and ``W_x``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Call args:
        x: Feature map to gate; presumably ``(batch, H, W, C_x)`` channels-last
            (Conv2D default) -- TODO confirm. ``C_x`` must be statically known.
        g: Gating signal. Any channel count (``W_gate`` projects it), but its
            other dimensions must broadcast against ``x``'s for the addition.
        training: Optional bool forwarded to the batch-norm layers.

    Returns:
        A tensor with the same shape as ``x``.
    """

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        # 1x1 projections of the gating signal and of x into `channels` maps.
        self.W_gate = tf.keras.Sequential(
            [kl.Conv2D(channels, 1, 1, padding="same"), kl.BatchNormalization()]
        )
        self.W_x = tf.keras.Sequential(
            [kl.Conv2D(channels, 1, 1, padding="same"), kl.BatchNormalization()]
        )
        # `phi` must output x's channel count (so the attention query/value
        # dims agree and `coeffs * x` is well-defined); that count is only
        # known once the first input arrives, so `phi` is created in `build`.
        self.phi = None
        self.attention = kl.Attention()

    def build(self, input_shape):
        """Create ``phi`` with as many output channels as ``x`` has.

        Args:
            input_shape: Shape of ``x`` (Keras passes the shape of the first
                positional call argument to a single-argument ``build``).

        Raises:
            ValueError: If ``x``'s last (channel) dimension is unknown.
        """
        x_channels = input_shape[-1]
        if x_channels is None:
            raise ValueError(
                "SelfAttention needs the channel dimension of `x` to be known; "
                "got input shape %s" % (input_shape,)
            )
        self.phi = tf.keras.Sequential(
            [kl.Conv2D(int(x_channels), 1, 1, padding="same"), kl.BatchNormalization()]
        )
        super().build(input_shape)

    def call(self, x, g, training=None):
        """Gate ``x`` with ``g``; see the class docstring for the pipeline."""
        g_proj = self.W_gate(g, training=training)
        x_proj = self.W_x(x, training=training)
        # Additive gating: ReLU(W_g g + W_x x), then phi + sigmoid -> (0, 1).
        coeffs = tf.nn.relu(g_proj + x_proj)
        coeffs = tf.math.sigmoid(self.phi(coeffs, training=training))
        # NOTE(review): kl.Attention receives rank-4 tensors here. In tf.keras
        # 2.x its matmuls batch over the leading two axes, so each row of the
        # feature map only attends over positions within that row; Keras 3's
        # Attention may reject rank-4 inputs. Confirm this is intended.
        coeffs = self.attention([x, coeffs])
        return coeffs * x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"channels": self.channels})
        return config


In [10]:
tf.random.set_seed(0)  # reproducible random inputs on Restart & Run All

layer = SelfAttention(64)

# g may have any channel count (W_gate projects it); x uses 64 channels to
# match `channels`. The gated output has the same shape as x.
x = tf.random.normal((1, 32, 32, 64))
g = tf.random.normal((1, 32, 32, 12))

out = layer(x, g)
assert out.shape == x.shape, out.shape
out.shape

TensorShape([1, 32, 32, 64])