In [None]:
from tensorflow import keras
from keras.ops import mean, max  # NOTE: `max` shadows the builtin max() in this module
from keras.layers import (
    Activation,
    Add,
    Average,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Multiply,
    Permute,
    Reshape,
)

In [None]:
# Keras/TensorFlow default to channels-last layout: the block below expects (batch, height, width, channels) inputs.
class resnetTripletAtt():
    """Triplet attention for channels-last (batch, height, width, channels) maps.

    The block is built with the Keras functional API. It runs three parallel
    branches and averages their outputs, so the result has the same shape as
    the input:

    * H-C branch: rotates the input to (H, C, W), Z-pools over W, and gates
      the (H, C) plane.
    * W-C branch: rotates the input to (C, W, H), Z-pools over H, and gates
      the (C, W) plane.
    * identity branch: Z-pools over C and gates the spatial (H, W) plane.

    Rotations use ``Permute``, a true axis transpose, so the spatial
    dimensions do not need to be known statically.

    Every call creates fresh layers, and therefore fresh weights. Calling the
    same instance on two tensors does not share weights between the calls.

    Args:
        k_size: ``kernel_size`` passed to every attention ``Conv2D``
            (int or tuple, as accepted by Keras).
    """

    def __init__(self, k_size):
        self.k_size = k_size

    # ---
    def __Z_pool(self, tensor):
        """Concatenate max- and mean-pooling over the last axis (-> 2 channels)."""
        avgpool = mean(tensor, axis=-1, keepdims=True)
        maxpool = max(tensor, axis=-1, keepdims=True)
        return Concatenate(axis=-1)([maxpool, avgpool])

    def __attention_gate(self, tensor):
        """Z-pool -> Conv2D(1) -> BN -> sigmoid, then rescale ``tensor`` by the map."""
        pooled = self.__Z_pool(tensor)
        conv = Conv2D(filters=1, kernel_size=self.k_size, strides=1, padding="same")(pooled)
        normed = BatchNormalization(axis=-1)(conv)
        attention_map = Activation(activation="sigmoid")(normed)
        # The map has a single channel; Multiply broadcasts it across every
        # channel of the (un-pooled) input so the output keeps its shape.
        return Multiply()([tensor, attention_map])

    def __branch_H_C(self, tensor):  # branch one
        """Attention over the (height, channel) plane; returns (H, W, C)."""
        # Permute dims are 1-based and exclude the batch axis.
        # (1, 3, 2) swaps W and C and is its own inverse.
        rotated = Permute((1, 3, 2))(tensor)       # (H, W, C) -> (H, C, W)
        attended = self.__attention_gate(rotated)  # Z-pool reduces W
        return Permute((1, 3, 2))(attended)        # (H, C, W) -> (H, W, C)

    def __branch_W_C(self, tensor):  # branch two
        """Attention over the (channel, width) plane; returns (H, W, C)."""
        # (3, 2, 1) swaps H and C and is its own inverse.
        rotated = Permute((3, 2, 1))(tensor)       # (H, W, C) -> (C, W, H)
        attended = self.__attention_gate(rotated)  # Z-pool reduces H
        return Permute((3, 2, 1))(attended)        # (C, W, H) -> (H, W, C)

    def __branch_identify(self, tensor):  # branch identify (third branch)
        """Spatial (height, width) attention without rotation; returns (H, W, C)."""
        return self.__attention_gate(tensor)

    def triplet_Attention(self, tensor):
        """Apply all three branches to ``tensor`` and return their element-wise mean."""
        br_1_out = self.__branch_H_C(tensor)
        br_2_out = self.__branch_W_C(tensor)
        br_3_out = self.__branch_identify(tensor)
        # ---
        # Average == (br_1 + br_2 + br_3) / 3, i.e. the "Add then scale by 1/3"
        # the block intends.
        return Average()([br_1_out, br_2_out, br_3_out])

    def __call__(self, tensor):
        return self.triplet_Attention(tensor)