# In [0]:
import tensorflow as tf

# In [0]:
class SEBlock(tf.keras.layers.Layer):
    """Squeeze-and-Excitation block.

    Squeezes the input with global average pooling, passes the per-channel
    descriptor through a bottleneck (Dense -> ReLU -> Dense -> sigmoid) and
    rescales every channel of the input by the resulting weight.

    Args:
        input_channels: Number of channels of the input tensor; also the
            width of the second Dense layer.
        r: Reduction ratio of the bottleneck. The hidden width is
            ``input_channels // r``, clamped to at least 1.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, input_channels, r=16, **kwargs):
        super(SEBlock, self).__init__(**kwargs)
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        # Clamp so inputs with fewer than ``r`` channels still get a valid
        # (non-zero-width) hidden layer.
        self.fc1 = tf.keras.layers.Dense(units=max(1, input_channels // r))
        self.fc2 = tf.keras.layers.Dense(units=input_channels)

    def call(self, inputs, **kwargs):
        # Assumes channels-last 4-D input: pooling yields (batch, channels).
        scale = self.pool(inputs)
        scale = tf.nn.relu(self.fc1(scale))
        scale = tf.nn.sigmoid(self.fc2(scale))
        # (batch, C) -> (batch, 1, 1, C) so the weights broadcast over H, W.
        scale = scale[:, tf.newaxis, tf.newaxis, :]
        # Plain element-wise product; avoids creating a new Multiply layer
        # on every call.
        return inputs * scale


class ChannelAttention(tf.keras.layers.Layer):
    """CBAM channel-attention map.

    Computes ``sigmoid(MLP(avg_pool(x)) + MLP(max_pool(x)))``. The MLP is
    shared between both pooled descriptors and is implemented as two 1x1
    convolutions with a ReLU between them (hidden width
    ``in_planes // ratio``, clamped to at least 1).

    Args:
        in_planes: Number of channels of the input (and of the output map).
        ratio: Reduction ratio of the hidden MLP layer.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    ``call`` returns a tensor of shape ``(batch, 1, 1, in_planes)`` with
    values in (0, 1), meant to be multiplied with the input feature map
    (assumes channels-last input).
    """

    def __init__(self, in_planes, ratio=16, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.avg = tf.keras.layers.GlobalAveragePooling2D()
        self.max = tf.keras.layers.GlobalMaxPooling2D()
        # Guard against a zero-filter bottleneck when in_planes < ratio.
        hidden = max(1, in_planes // ratio)
        self.conv1 = tf.keras.layers.Conv2D(filters=hidden,
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

        self.conv2 = tf.keras.layers.Conv2D(filters=in_planes,
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def _shared_mlp(self, descriptor):
        """Apply the shared two-layer MLP: 1x1 conv -> ReLU -> 1x1 conv."""
        # The non-linearity belongs to the hidden layer. A ReLU after conv2
        # would force the pre-sigmoid sum to be >= 0 and pin every
        # attention weight to >= 0.5.
        return self.conv2(tf.nn.relu(self.conv1(descriptor)))

    def call(self, inputs):
        # Global pooling yields (batch, C); restore (batch, 1, 1, C) so the
        # 1x1 convolutions apply and the result broadcasts over H and W.
        avg_desc = self.avg(inputs)[:, tf.newaxis, tf.newaxis, :]
        max_desc = self.max(inputs)[:, tf.newaxis, tf.newaxis, :]
        logits = self._shared_mlp(avg_desc) + self._shared_mlp(max_desc)
        return tf.nn.sigmoid(logits)

class SpatialAttention(tf.keras.layers.Layer):
    """CBAM spatial-attention map.

    Pools the input across channels (mean and max), joins the two
    single-channel maps and convolves them into one map that a sigmoid
    squashes into (0, 1).

    Args:
        kernel_size: Size of the convolution applied to the pooled maps.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    ``call`` returns a tensor of shape ``(batch, H, W, 1)`` (assuming
    channels-last input), meant to be multiplied with the input feature map.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.conv1 = tf.keras.layers.Conv2D(filters=1,
                                            kernel_size=kernel_size,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def call(self, inputs):
        # keepdims gives (batch, H, W, 1) for each. Concatenating them on the
        # last axis is equivalent to stacking the reduced maps.
        avg_out = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_out = tf.reduce_max(inputs, axis=-1, keepdims=True)
        out = tf.concat([avg_out, max_out], axis=-1)
        out = self.conv1(out)
        # Sigmoid, not ReLU: the map must be a (0, 1) gate, not an unbounded
        # rescale of the features.
        return tf.nn.sigmoid(out)

class ConvBlockAttentionModule(tf.keras.layers.Layer):
    """Convolutional Block Attention Module (CBAM).

    Refines a feature map in two sequential steps:
    1. re-weight it with the map produced by ``ChannelAttention``;
    2. re-weight the result with the map produced by ``SpatialAttention``.

    Args:
        out_channels: Channel count of the incoming feature map.
        ratio: Reduction ratio forwarded to the channel-attention MLP.
        kernel_size: Convolution size forwarded to spatial attention.
    """

    def __init__(self, out_channels, ratio = 16, kernel_size = 7):
        super(ConvBlockAttentionModule, self).__init__()
        self.ca = ChannelAttention(in_planes=out_channels, ratio=ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def call(self, inputs, **kwargs):
        # Channel step first; the spatial map is computed from its output.
        channel_refined = inputs * self.ca(inputs)
        return channel_refined * self.sa(channel_refined)


class MobileNet_classification(tf.keras.layers.Layer):
  """Classification head: global pooling followed by a dense output layer.

  Args:
    pooling: 'avg' for global average pooling or 'max' for global max pooling.
    classes: Number of output units. With ``classes == 1`` a single sigmoid
      unit is used; otherwise a softmax over ``classes`` units.
    **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

  Raises:
    ValueError: If ``pooling`` is neither 'avg' nor 'max'.
  """

  def __init__(self, pooling='avg', classes=1, **kwargs):
    super(MobileNet_classification, self).__init__(**kwargs)
    # Validate eagerly so a typo surfaces at construction time instead of
    # as an UnboundLocalError deep inside call().
    if pooling not in ('avg', 'max'):
      raise ValueError(
          "pooling must be 'avg' or 'max', got {!r}".format(pooling))
    self.classes = classes
    self.pooling = pooling

    self.avg_pooling = tf.keras.layers.GlobalAveragePooling2D()
    self.max_pooling = tf.keras.layers.GlobalMaxPooling2D()

    self.sigmoid_fc = tf.keras.layers.Dense(units=classes, activation=tf.keras.activations.sigmoid)
    self.softmax_fc = tf.keras.layers.Dense(units=classes, activation=tf.keras.activations.softmax)

  def call(self, inputs):
    pool = self.avg_pooling if self.pooling == 'avg' else self.max_pooling
    x = pool(inputs)
    # Binary output uses one sigmoid unit; multi-class uses softmax.
    head = self.sigmoid_fc if self.classes == 1 else self.softmax_fc
    return head(x)

class depthwise_separable_convolution(tf.keras.layers.Layer):
  """MobileNet depthwise-separable block with optional attention.

  Applies 3x3 depthwise conv -> BN -> ReLU, then a 1x1 pointwise conv with
  ``filter_num`` filters -> BN -> ReLU. Afterwards it optionally applies an SE
  block and/or a CBAM block; SE runs first when both are enabled.

  Args:
    filter_num: Output channels of the pointwise convolution.
    stride: Stride of the depthwise convolution.
    padd: Padding mode of the depthwise convolution ("same" or "valid").
    use_bias: Whether both convolutions use a bias term.
    use_se: Apply an SE block to the block output.
    use_cbam: Apply a CBAM block to the block output.
    **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
  """

  def __init__(self, filter_num, stride, padd="same", use_bias=False, use_se=False, use_cbam=False, **kwargs):
    super(depthwise_separable_convolution, self).__init__(**kwargs)
    # Depthwise: spatial filtering, one 3x3 kernel per input channel.
    self.dw_conv_1 = tf.keras.layers.DepthwiseConv2D(kernel_size=(3, 3),
                                                     strides=stride,
                                                     padding=padd,
                                                     use_bias=use_bias)
    self.bn1 = tf.keras.layers.BatchNormalization()

    # Pointwise: 1x1 conv mixing channels to filter_num outputs.
    self.conv_2 = tf.keras.layers.Conv2D(filters=filter_num,
                                         kernel_size=(1, 1),
                                         strides=1,
                                         padding="same",
                                         use_bias=use_bias)
    self.bn2 = tf.keras.layers.BatchNormalization()

    # Attention blocks: both are always constructed; call() only applies
    # the ones whose flag is set.
    self.use_se = use_se
    self.se_block = SEBlock(input_channels=filter_num)
    self.use_cbam = use_cbam
    self.cbam_block = ConvBlockAttentionModule(out_channels=filter_num)

  def call(self, inputs):
    x = self.dw_conv_1(inputs)
    x = self.bn1(x)
    x = tf.nn.relu(x)

    x = self.conv_2(x)
    x = self.bn2(x)
    x = tf.nn.relu(x)

    if self.use_se:
      x = self.se_block(x)
    if self.use_cbam:
      x = self.cbam_block(x)

    return x
    

class MobileNet(tf.keras.Model):
  """MobileNet V1 (width multiplier 1.0) with optional SE / CBAM attention.

  Every depthwise-separable block can be followed by an SE block
  (``use_se``) and/or a CBAM block (``use_cbam``).

  Stage layout and resulting spatial size for a 224x224 input:

    ZeroPad + 3x3 conv 32, stride 2         -> 112x112
    dw-sep   64, stride 1                   -> 112x112
    ZeroPad + dw-sep 128, stride 2          -> 56x56
    dw-sep  128, stride 1                   -> 56x56
    ZeroPad + dw-sep 256, stride 2          -> 28x28
    dw-sep  256, stride 1                   -> 28x28
    ZeroPad + dw-sep 512, stride 2          -> 14x14
    5 x dw-sep 512, stride 1 (distinct)     -> 14x14
    ZeroPad + dw-sep 1024, stride 2         -> 7x7
    dw-sep 1024, stride 1                   -> 7x7
    global average pooling + dense          (only if ``include_top``)

  Args:
    include_top: If True, append the pooling + dense classification head;
      otherwise ``call`` returns the final feature map.
    classes: Number of head outputs (1 -> sigmoid unit, otherwise softmax).
    use_se: Enable SE attention in every depthwise-separable block.
    use_cbam: Enable CBAM attention in every depthwise-separable block.
    **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
  """

  def __init__(self, include_top=True, classes=1, use_se=False, use_cbam=False, **kwargs):
    super(MobileNet, self).__init__(**kwargs)

    self.include_top = include_top

    def dw_block(filters, stride):
      """Build one depthwise-separable block with this model's attention flags."""
      # Stride-2 blocks follow an explicit asymmetric ZeroPadding2D and use
      # "valid" padding; stride-1 blocks use "same" padding.
      return depthwise_separable_convolution(filter_num=filters,
                                             stride=stride,
                                             padd="valid" if stride == 2 else "same",
                                             use_se=use_se,
                                             use_cbam=use_cbam)

    # Stem: pad bottom/right by one, then a "valid" stride-2 3x3 conv.
    self.zero_padd_1 = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)))
    self.conv_1 = tf.keras.layers.Conv2D(filters=32,
                                         kernel_size=(3, 3),
                                         strides=2,
                                         padding="valid",
                                         use_bias=False)
    self.bn_1 = tf.keras.layers.BatchNormalization()

    self.dw_separable_block_1 = dw_block(64, 1)

    self.zero_padd_2 = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)))
    self.dw_separable_block_2 = dw_block(128, 2)
    self.dw_separable_block_3 = dw_block(128, 1)

    self.zero_padd_3 = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)))
    self.dw_separable_block_4 = dw_block(256, 2)
    self.dw_separable_block_5 = dw_block(256, 1)

    self.zero_padd_4 = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)))
    self.dw_separable_block_6 = dw_block(512, 2)
    # Five *distinct* blocks so each has its own weights and BatchNorm
    # statistics; applying one layer five times would share them.
    self.dw_separable_blocks_7 = [dw_block(512, 1) for _ in range(5)]

    self.zero_padd_5 = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)))
    self.dw_separable_block_8 = dw_block(1024, 2)
    # Stride 1: the last block keeps the 7x7 resolution, matching the
    # reference Keras MobileNet V1.
    self.dw_separable_block_9 = dw_block(1024, 1)

    self.fc_1 = MobileNet_classification(pooling='avg',
                                         classes=classes)

  def call(self, inputs):
    x = self.zero_padd_1(inputs)
    x = self.conv_1(x)
    x = self.bn_1(x)
    x = tf.nn.relu(x)
    x = self.dw_separable_block_1(x)
    x = self.zero_padd_2(x)
    x = self.dw_separable_block_2(x)
    x = self.dw_separable_block_3(x)
    x = self.zero_padd_3(x)
    x = self.dw_separable_block_4(x)
    x = self.dw_separable_block_5(x)
    x = self.zero_padd_4(x)
    x = self.dw_separable_block_6(x)
    for block in self.dw_separable_blocks_7:
      x = block(x)
    x = self.zero_padd_5(x)
    x = self.dw_separable_block_8(x)
    x = self.dw_separable_block_9(x)
    if self.include_top:
      x = self.fc_1(x)
    return x

# In [0]:
def MobileNet_V1(classes):
  """Build a plain MobileNet V1 classifier (no attention) with ``classes`` outputs."""
  model = MobileNet(classes=classes, include_top=True)
  return model

def se_MobileNet_V1(classes):
  """Build a MobileNet V1 classifier with SE attention in every dw-separable block."""
  model = MobileNet(classes=classes, use_se=True, include_top=True)
  return model

def cbam_MobileNet_V1(classes):
  """Build a MobileNet V1 classifier with CBAM attention in every dw-separable block."""
  model = MobileNet(classes=classes, use_cbam=True, include_top=True)
  return model