# In [1]:
import tensorflow as tf

from tensorflow.keras import *

class BasicConv(layers.Layer):
    """Stride-1 Conv2D followed by BatchNormalization and ReLU.

    Args:
        out_planes: Number of filters (output channels) of the convolution.
        kernel_size: Side length of the square convolution kernel.
    """

    def __init__(self, out_planes, kernel_size):
        super(BasicConv, self).__init__()
        self.conv = tf.keras.layers.Conv2D(
            out_planes,
            kernel_size=(kernel_size, kernel_size),
            strides=1,
            # Keras only accepts "valid" or "same" for `padding`; an integer
            # pad width is rejected. With stride 1, "same" preserves the
            # spatial size, which is the result a symmetric
            # (kernel_size - 1) // 2 pad gives for odd kernel sizes.
            padding="same",
            use_bias=False,  # no conv bias; BatchNormalization follows
        )
        self.bn = tf.keras.layers.BatchNormalization(
            momentum=0.999,
            epsilon=1e-5,
        )

    def call(self, x, training=None):
        """Apply convolution, batch normalization and ReLU to `x`.

        Args:
            x: Input tensor accepted by `Conv2D`.
            training: Optional flag forwarded to the batch-normalization
                layer. When left as None, Keras resolves the mode itself.

        Returns:
            The activated feature map with `out_planes` channels.
        """
        x = self.conv(x)
        x = self.bn(x, training=training)
        return tf.nn.relu(x)

class ChannelPool(layers.Layer):
    """Reduce axis 3 of the input to two entries: its max and its mean."""

    def call(self, x):
        """Return the max and mean over axis 3, concatenated along axis 3.

        Args:
            x: Input tensor with at least four axes.

        Returns:
            A tensor shaped like `x` except that axis 3 has size 2; index 0
            holds the maximum and index 1 the mean over the original axis 3.
        """
        # keepdims=True leaves a size-1 axis 3 in place, so the two pooled
        # maps can be concatenated on that axis directly.
        pooled_max = tf.reduce_max(x, axis=3, keepdims=True)
        pooled_mean = tf.reduce_mean(x, axis=3, keepdims=True)
        return tf.concat([pooled_max, pooled_mean], axis=3)

class SpatialGate(layers.Layer):
    """Rescale the input by a single-channel sigmoid attention map.

    The map is computed by pooling axis 3 with `ChannelPool` and passing the
    two pooled maps through a 3x3 `BasicConv` with one output channel.
    """

    def __init__(self):
        super().__init__()
        self.compress = ChannelPool()
        # One output channel, 3x3 kernel.
        self.spatial = BasicConv(1, 3)

    def call(self, x):
        """Return `x` multiplied by its attention map.

        Args:
            x: Input tensor; its axis 3 is the one pooled away by
                `ChannelPool`.

        Returns:
            `x` scaled elementwise by the sigmoid of the convolved pooled
            maps (the map is multiplied against `x` with broadcasting).
        """
        pooled = self.compress(x)
        # NOTE(review): BasicConv ends with a ReLU, so the sigmoid input is
        # non-negative and the gate is never below 0.5 — confirm this is
        # intended.
        attention = tf.nn.sigmoid(self.spatial(pooled))
        return x * attention

class TripletAttention(layers.Layer):
    """Average of several `SpatialGate` branches applied to permuted inputs.

    Two branches always run: one on the input with axes 1 and 3 swapped and
    one with axes 2 and 3 swapped. Unless `no_spatial` is set, a third
    branch gates the input in its original layout.

    Args:
        no_spatial: When True, the third (unpermuted) gate is not created
            and the output is the mean of the two permuted branches only.
    """

    def __init__(self, no_spatial=False):
        super().__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not self.no_spatial:
            self.SpatialGate = SpatialGate()

    def call(self, x):
        """Apply the attention branches to `x` and average their outputs.

        Args:
            x: 4-D input tensor.
                # assumes layout (batch, height, width, channels) — TODO confirm

        Returns:
            A tensor with the same shape as `x`: the mean of the two
            permuted-branch outputs, plus the unpermuted branch when
            `no_spatial` is False.
        """
        # Each permutation swaps two axes, so applying it a second time
        # restores the original layout.
        swap_1_3 = [0, 3, 2, 1]
        swap_2_3 = [0, 1, 3, 2]

        # Branch 1: move axis 1 into the pooled position (axis 3), gate,
        # then undo the swap.
        gated_1 = self.ChannelGateH(tf.transpose(x, perm=swap_1_3))
        branch_1 = tf.transpose(gated_1, perm=swap_1_3)

        # Branch 2: same procedure with axis 2 in the pooled position.
        gated_2 = self.ChannelGateW(tf.transpose(x, perm=swap_2_3))
        branch_2 = tf.transpose(gated_2, perm=swap_2_3)

        if self.no_spatial:
            return (1/2) * (branch_1 + branch_2)

        # Branch 3: gate the input in its original layout.
        branch_3 = self.SpatialGate(x)
        return (1/3) * (branch_3 + branch_1 + branch_2)