In [0]:
import tensorflow as tf

In [0]:
# Per-variant list of how many conv layers go into each of the five VGG stages
# (index 0 is the first stage, index 4 the last).
layer_in_block = {
    'vgg11': [1, 1, 2, 2, 2],
    'vgg13': [2, 2, 2, 2, 2],
    'vgg16': [2, 2, 3, 3, 3],
    'vgg19': [2, 2, 4, 4, 4],
}

class SEBlock(tf.keras.layers.Layer):
    """Squeeze-and-excitation style channel gate.

    The input is globally average-pooled, passed through two Dense layers
    (ReLU after the first, sigmoid after the second) and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        input_channels: number of units of the second Dense layer; the first
            Dense layer has ``input_channels // r`` units.
        r: integer reduction factor for the first Dense layer.
    """

    def __init__(self, input_channels, r=16):
        super().__init__()
        reduced_units = input_channels // r
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc1 = tf.keras.layers.Dense(units=reduced_units)
        self.fc2 = tf.keras.layers.Dense(units=input_channels)

    def call(self, inputs, **kwargs):
        """Return ``inputs`` scaled by the learned per-channel gate."""
        squeezed = self.pool(inputs)
        hidden = tf.nn.relu(self.fc1(squeezed))
        gate = tf.nn.sigmoid(self.fc2(hidden))
        # Insert two singleton axes right after the batch axis so the gate can
        # broadcast against the input (assumes a channels-last 4-D input).
        for _ in range(2):
            gate = tf.expand_dims(input=gate, axis=1)
        return tf.keras.layers.multiply(inputs=[inputs, gate])


class ChannelAttention(tf.keras.layers.Layer):
    """Channel-attention gate built from average- and max-pooled descriptors.

    Both pooled descriptors go through the same two 1x1 convolutions
    (``conv1`` -> ReLU -> ``conv2``); the two results are summed and squashed
    with a sigmoid. The returned tensor is the attention map only; the caller
    is expected to multiply it with the feature map.

    Args:
        in_planes: number of filters of the second 1x1 convolution. It should
            match the channel count of the input so the map can be broadcast
            against it.
        ratio: integer reduction factor; the first 1x1 convolution has
            ``in_planes // ratio`` filters.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg = tf.keras.layers.GlobalAveragePooling2D()
        self.max = tf.keras.layers.GlobalMaxPooling2D()
        self.conv1 = tf.keras.layers.Conv2D(filters=in_planes // ratio,
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')
        self.conv2 = tf.keras.layers.Conv2D(filters=in_planes,
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def _shared_mlp(self, pooled):
        """Run one pooled descriptor through the shared conv1 -> ReLU -> conv2 stack."""
        # Add two singleton axes after the batch axis so the 1x1 convolutions
        # can be applied (assumes the pooling layers return (batch, channels)).
        pooled = tf.expand_dims(tf.expand_dims(pooled, axis=1), axis=1)
        # The non-linearity belongs between the two convolutions. Applying it
        # after conv2 makes the bottleneck purely linear and feeds only
        # non-negative values to the sigmoid, pinning the gate to [0.5, 1).
        hidden = tf.nn.relu(self.conv1(pooled))
        return self.conv2(hidden)

    def call(self, inputs):
        """Return the sigmoid channel-attention map for ``inputs``."""
        avg_out = self._shared_mlp(self.avg(inputs))
        max_out = self._shared_mlp(self.max(inputs))
        return tf.nn.sigmoid(avg_out + max_out)

class SpatialAttention(tf.keras.layers.Layer):
    """Spatial-attention gate computed from channel-wise mean and max maps.

    The input is reduced along axis 3 with both a mean and a max, the two
    single-channel maps are concatenated, and a single-filter convolution
    followed by a sigmoid turns them into a map with values in (0, 1). The
    returned tensor is the attention map only; the caller is expected to
    multiply it with the feature map.

    Args:
        kernel_size: kernel size of the convolution applied to the
            two-channel descriptor.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=1,
                                            kernel_size=kernel_size,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def call(self, inputs):
        """Return the sigmoid spatial-attention map for ``inputs``."""
        # Axis 3 is treated as the channel axis (channels-last 4-D input).
        avg_out = tf.reduce_mean(inputs, axis=3, keepdims=True)
        max_out = tf.reduce_max(inputs, axis=3, keepdims=True)
        descriptor = tf.concat([avg_out, max_out], axis=3)
        # Sigmoid, not ReLU: the result is used as a multiplicative gate, so it
        # must be bounded. ReLU could zero out regions outright or scale
        # activations without limit.
        return tf.nn.sigmoid(self.conv1(descriptor))

class ConvBlockAttentionModule(tf.keras.layers.Layer):
    """Applies channel attention and then spatial attention to a feature map.

    Args:
        out_channels: passed to ``ChannelAttention`` as ``in_planes``.
        ratio: passed to ``ChannelAttention`` as ``ratio``.
        kernel_size: passed to ``SpatialAttention`` as ``kernel_size``.
    """

    def __init__(self, out_channels, ratio = 16, kernel_size = 7):
        super().__init__()
        self.ca = ChannelAttention(in_planes=out_channels, ratio=ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def call(self, inputs, **kwargs):
        """Return ``inputs`` re-weighted by both attention maps in sequence."""
        # First scale by the channel map, then scale that result by the
        # spatial map computed from it.
        channel_refined = self.ca(inputs) * inputs
        return self.sa(channel_refined) * channel_refined

class VggConv(tf.keras.layers.Layer):
  """Keras layer wrapping a single ``Conv2D``.

  Args:
    filter_num: number of convolution filters.
    kernel_size: convolution kernel size.
    activation: activation passed to the convolution.
    padding: padding mode passed to the convolution.
    kernel_initializer: kernel initializer passed to the convolution.
  """

  def __init__(self, filter_num=None, kernel_size=(3, 3),
               activation='relu', padding='same', kernel_initializer='he_normal'):
    super().__init__()
    conv_options = dict(filters=filter_num,
                        kernel_size=kernel_size,
                        activation=activation,
                        padding=padding,
                        kernel_initializer=kernel_initializer)
    self.conv1 = tf.keras.layers.Conv2D(**conv_options)

  def call(self, inputs, training=None):
    """Apply the convolution; ``training`` is accepted but not used here."""
    return self.conv1(inputs)

class VggConvs(tf.keras.layers.Layer):
  """Stage tail: batch normalization followed by max pooling.

  Args:
    pool_size: window size handed to ``MaxPool2D``.
  """

  def __init__(self, pool_size=(2, 2)):
    super().__init__()
    self.bn1 = tf.keras.layers.BatchNormalization()
    self.pool1 = tf.keras.layers.MaxPool2D(pool_size=pool_size)

  def call(self, inputs, training=None):
    """Normalize ``inputs`` (honouring ``training``) and down-sample the result."""
    normalized = self.bn1(inputs, training=training)
    return self.pool1(normalized)

class VggDense(tf.keras.layers.Layer):
  """Classification head: Flatten -> Dense(relu) -> BatchNormalization -> Dense.

  Args:
    filter_num: number of units in the hidden Dense layer.
    classes: number of output units. When it equals 1 the output layer uses a
      sigmoid activation, otherwise it uses softmax.
  """

  def __init__(self, filter_num=None, classes=1):
    super(VggDense, self).__init__()
    self.Flatten = tf.keras.layers.Flatten()
    self.Dense1 = tf.keras.layers.Dense(filter_num, activation='relu')
    self.bn1 = tf.keras.layers.BatchNormalization()
    # Single-unit output -> sigmoid; multi-unit output -> softmax.
    output_activation = (tf.keras.activations.sigmoid if classes == 1
                         else tf.keras.activations.softmax)
    self.Dense2 = tf.keras.layers.Dense(classes, activation=output_activation)

  def call(self, inputs, training=None):
    """Return the head's activated output for ``inputs``.

    Args:
      inputs: feature tensor; it is flattened before the Dense layers.
      training: forwarded to the batch-normalization layer, mirroring how
        ``VggConvs`` handles its own normalization. ``None`` leaves the
        decision to Keras.
    """
    x = self.Flatten(inputs)
    x = self.Dense1(x)
    x = self.bn1(x, training=training)
    return self.Dense2(x)


def vgg_convs_layer(filter_num=None, blocks=None, kernel_size=(3, 3),
                    activation='relu', padding='same', kernel_initializer='he_normal',
                    pool_size=(2, 2), use_se=False, use_cbam=False):
  """Build one VGG stage as a ``tf.keras.Sequential``.

  The stage consists of ``blocks`` ``VggConv`` layers, one ``VggConvs``
  (batch normalization + max pooling) and, optionally, attention layers
  appended after the pooling.

  Args:
    filter_num: number of filters for every convolution in the stage; also
      used as the channel count for the optional attention layers.
    blocks: how many ``VggConv`` layers to stack.
    kernel_size: kernel size for each ``VggConv``.
    activation: activation for each ``VggConv``.
    padding: padding mode for each ``VggConv``.
    kernel_initializer: kernel initializer for each ``VggConv``.
    pool_size: pooling window for the trailing ``VggConvs``.
    use_se: when truthy, append an ``SEBlock``.
    use_cbam: when truthy, append a ``ConvBlockAttentionModule``. If both
      flags are set, the SE block is added first.

  Returns:
    The assembled ``tf.keras.Sequential`` stage.
  """
  vgg_block = tf.keras.Sequential()
  for _ in range(blocks):
    vgg_block.add(VggConv(filter_num=filter_num,
                          kernel_size=kernel_size,
                          activation=activation,
                          padding=padding,
                          kernel_initializer=kernel_initializer))
  vgg_block.add(VggConvs(pool_size=pool_size))
  # Flags are tested by truthiness rather than compared with `== True`.
  if use_se:
    vgg_block.add(SEBlock(input_channels=filter_num))
  if use_cbam:
    vgg_block.add(ConvBlockAttentionModule(out_channels=filter_num))
  return vgg_block


class VggNet(tf.keras.Model):
  """VGG-style classifier made of five conv stages and a dense head.

  Args:
    layer: key into ``layer_in_block`` selecting how many conv layers each
      stage contains.
    use_se: forwarded to every stage built by ``vgg_convs_layer``.
    use_cbam: forwarded to every stage built by ``vgg_convs_layer``.
    classes: number of output units of the ``VggDense`` head.
  """

  def __init__(self, layer='vgg16', use_se = False, use_cbam = False, classes=1):
    super().__init__()

    depths = layer_in_block[layer]
    attention = dict(use_se=use_se, use_cbam=use_cbam)

    # Stage widths are 64, 128, 256, 512, 512.
    self.conv1 = vgg_convs_layer(filter_num=64, blocks=depths[0], **attention)
    self.conv2 = vgg_convs_layer(filter_num=128, blocks=depths[1], **attention)
    self.conv3 = vgg_convs_layer(filter_num=256, blocks=depths[2], **attention)
    self.conv4 = vgg_convs_layer(filter_num=512, blocks=depths[3], **attention)
    self.conv5 = vgg_convs_layer(filter_num=512, blocks=depths[4], **attention)
    self.dense = VggDense(filter_num=256, classes=classes)

  def call(self, inputs, cam=None, feature_map=None):
    """Run the network and return its output in one of three forms.

    Args:
      inputs: input tensor fed to the first stage.
      cam: when equal to ``'grad'``, return ``(last_stage_output, prediction)``.
      feature_map: when equal to ``'fm'`` (and ``cam`` is not ``'grad'``),
        return the five stage outputs followed by the prediction as a 6-tuple.

    Returns:
      The prediction alone unless one of the modes above is selected.
    """
    stage_outputs = []
    features = inputs
    for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
      features = stage(features)
      stage_outputs.append(features)
    prediction = self.dense(features)

    # `cam` takes precedence over `feature_map` when both are given.
    if cam == 'grad':
      return stage_outputs[-1], prediction
    if feature_map == 'fm':
      return (*stage_outputs, prediction)
    return prediction

In [0]:
def vgg_11(classes):
  """Return a ``VggNet`` built with ``layer='vgg11'`` and the given ``classes``."""
  model = VggNet(classes=classes, layer='vgg11')
  return model

def vgg_13(classes):
  """Return a ``VggNet`` built with ``layer='vgg13'`` and the given ``classes``."""
  model = VggNet(classes=classes, layer='vgg13')
  return model

def vgg_16(classes):
  """Return a ``VggNet`` built with ``layer='vgg16'`` and the given ``classes``."""
  model = VggNet(classes=classes, layer='vgg16')
  return model

def vgg_19(classes):
  """Return a ``VggNet`` built with ``layer='vgg19'`` and the given ``classes``."""
  model = VggNet(classes=classes, layer='vgg19')
  return model

def se_vgg_16(classes):
  """Return a ``VggNet`` with ``layer='vgg16'`` and ``use_se=True``."""
  model = VggNet(classes=classes, layer='vgg16', use_se=True)
  return model

def se_vgg_19(classes):
  """Return a ``VggNet`` with ``layer='vgg19'`` and ``use_se=True``."""
  model = VggNet(classes=classes, layer='vgg19', use_se=True)
  return model

def cbam_vgg_16(classes):
  """Return a ``VggNet`` with ``layer='vgg16'`` and ``use_cbam=True``."""
  model = VggNet(classes=classes, layer='vgg16', use_cbam=True)
  return model

def cbam_vgg_19(classes):
  """Return a ``VggNet`` with ``layer='vgg19'`` and ``use_cbam=True``."""
  model = VggNet(classes=classes, layer='vgg19', use_cbam=True)
  return model
