In [0]:
import tensorflow as tf

In [0]:
# Number of residual blocks in each of the four stages (res_block1 ..
# res_block4 of BasicResNet / BottleResNet), keyed by architecture name.
layers_in_block = dict([
    ('resnet-18', [2, 2, 2, 2]),
    ('resnet-34', [3, 4, 6, 3]),
    ('resnet-50', [3, 4, 6, 3]),
    ('resnet-101', [3, 4, 23, 3]),
    ('resnet-152', [3, 8, 36, 3]),
])

class ResNet_Dense(tf.keras.layers.Layer):
  """Classification head: global spatial pooling followed by a dense layer.

  Args:
    pooling: 'avg' to use GlobalAveragePooling2D or 'max' to use
      GlobalMaxPooling2D over the incoming feature map.
    classes: number of output units. With classes == 1 the dense layer uses
      a sigmoid activation; otherwise it uses softmax.

  Raises:
    ValueError: if `pooling` is not 'avg' or 'max' (checked at construction
      time so the error is not deferred to the first `call`).
  """
  def __init__(self, pooling='avg', classes=1):
    super(ResNet_Dense, self).__init__()
    if pooling not in ('avg', 'max'):
      raise ValueError(
          "pooling must be 'avg' or 'max', got {!r}".format(pooling))
    self.classes = classes
    self.pooling = pooling

    self.avg_pooling = tf.keras.layers.GlobalAveragePooling2D()
    self.max_pooling = tf.keras.layers.GlobalMaxPooling2D()

    # Both heads are declared, but `call` only ever invokes the one that
    # matches `classes`.
    self.sigmoid_fc = tf.keras.layers.Dense(units=classes, activation=tf.keras.activations.sigmoid)
    self.softmax_fc = tf.keras.layers.Dense(units=classes, activation=tf.keras.activations.softmax)

  def call(self, inputs):
    """Pool `inputs` spatially and return the head's activations."""
    pool = self.avg_pooling if self.pooling == 'avg' else self.max_pooling
    head = self.sigmoid_fc if self.classes == 1 else self.softmax_fc
    return head(pool(inputs))

def make_dense_layer(pooling='avg', classes=1):
  """Return a Sequential model that holds a single ResNet_Dense head.

  Args:
    pooling: forwarded to ResNet_Dense ('avg' or 'max').
    classes: forwarded to ResNet_Dense (number of output units).
  """
  head = ResNet_Dense(pooling=pooling, classes=classes)
  return tf.keras.Sequential([head])

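## Use stride 2 in the first block of a stage so the shortcut and main-path dimensions match (basic and bottleneck blocks) — dimension을 맞추기 위해 stride 2 설정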

In [0]:
class Basic_building_block(tf.keras.layers.Layer):
  """Two-convolution residual block (the "basic" block of ResNet-18/34).

  Computes relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).

  The shortcut is the identity unless the block changes the spatial size
  (stride != 1) or the channel count (input channels != filter_num). In
  either case a 1x1 convolution + BatchNormalization projection is used so
  that both branches can be added.

  Args:
    filter_num: number of output channels of both main-path convolutions.
    kernel_size: kernel size of both main-path convolutions.
    stride: stride of the first convolution and of the projection shortcut.
    padding: padding mode of the main-path convolutions.
  """
  def __init__(self, filter_num=None, kernel_size=(3, 3), \
               stride=1, padding='same'):
    super(Basic_building_block, self).__init__()
    self.filter_num = filter_num
    self.stride = stride

    # First layer of the block (carries the stride).
    self.conv1 = tf.keras.layers.Conv2D(filters=filter_num,
                                        kernel_size=kernel_size,
                                        strides=stride,
                                        padding=padding)
    self.bn1 = tf.keras.layers.BatchNormalization()

    # Second layer of the block.
    self.conv2 = tf.keras.layers.Conv2D(filters=filter_num,
                                        kernel_size=kernel_size,
                                        strides=1,
                                        padding=padding)
    self.bn2 = tf.keras.layers.BatchNormalization()

    if stride != 1:
      self.downsample = self._make_projection()
    else:
      # Identity for now; build() replaces it with a projection if the
      # input channel count turns out to differ from filter_num.
      self.downsample = lambda x: x

  def _make_projection(self):
    """Create the 1x1 conv + BN shortcut used when dimensions change."""
    projection = tf.keras.Sequential()
    projection.add(tf.keras.layers.Conv2D(filters=self.filter_num,
                                          kernel_size=(1, 1),
                                          strides=self.stride))
    projection.add(tf.keras.layers.BatchNormalization())
    return projection

  def build(self, input_shape):
    """Switch to a projection shortcut on a stride-1 channel mismatch."""
    in_channels = tf.TensorShape(input_shape)[-1]
    if (self.stride == 1 and in_channels is not None
        and in_channels != self.filter_num):
      self.downsample = self._make_projection()
    super(Basic_building_block, self).build(input_shape)

  def call(self, inputs, training=None):
    x_shortcut = self.downsample(inputs)

    x = self.conv1(inputs)
    x = self.bn1(x, training=training)
    x = tf.nn.relu(x)
    x = self.conv2(x)
    x = self.bn2(x, training=training)
    x = tf.nn.relu(tf.keras.layers.add([x_shortcut, x]))

    return x

def make_basic_block_layer(filter_num, blocks, stride=1):
  """Stack `blocks` Basic_building_block layers into one residual stage.

  Only the first block uses `stride`, so it alone changes the spatial
  resolution; the remaining blocks use stride 1. Any stride value is
  honoured, matching make_bottle_block_layer (previously only stride == 2
  was applied and other values were silently replaced by 1).

  Args:
    filter_num: output channels of every block in the stage.
    blocks: number of blocks in the stage.
    stride: stride of the first block.

  Returns:
    A tf.keras.Sequential containing the blocks.
  """
  res_block = tf.keras.Sequential()
  for i in range(blocks):
    block_stride = stride if i == 0 else 1
    res_block.add(Basic_building_block(filter_num, stride=block_stride))
  return res_block

class BasicResNet(tf.keras.Model):
  """ResNet-18/34 style network built from Basic_building_block stages.

  Architecture:
    - Stem: 7x7 stride-2 conv, BN, ReLU, then a 3x3 stride-2 max pool.
    - Four stages with 64/128/256/512 filters; stages 2-4 start with stride 2.
    - A global-average-pooling dense head.

  Args:
    layer: key into `layers_in_block` selecting the number of blocks per stage.
    classes: number of output units (sigmoid head for 1, softmax otherwise).
  """
  def __init__(self, layer='resnet-34', classes=1):
    super(BasicResNet, self).__init__()

    self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='same', kernel_initializer='he_normal')
    self.bn1 = tf.keras.layers.BatchNormalization(axis=3)
    # padding='same' gives the standard stem resolution (e.g. 112x112 ->
    # 56x56), matching BottleResNet's explicit 1-pixel zero padding; the
    # default 'valid' padding pooled a 112x112 map down to 55x55.
    self.max_pooling = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')

    self.res_block1 = make_basic_block_layer(64, layers_in_block[layer][0], 1)
    self.res_block2 = make_basic_block_layer(128, layers_in_block[layer][1], 2)
    self.res_block3 = make_basic_block_layer(256, layers_in_block[layer][2], 2)
    self.res_block4 = make_basic_block_layer(512, layers_in_block[layer][3], 2)

    self.fc_block1 = make_dense_layer('avg', classes)

  def call(self, inputs, training=None):
    x = self.conv1(inputs)
    x = self.bn1(x, training=training)
    x = tf.nn.relu(x)
    x = self.max_pooling(x)

    # Forward `training` explicitly so the BatchNormalization layers inside
    # the stages run in the same mode as the stem.
    x = self.res_block1(x, training=training)
    x = self.res_block2(x, training=training)
    x = self.res_block3(x, training=training)
    x = self.res_block4(x, training=training)

    x = self.fc_block1(x, training=training)

    return x

In [0]:
class SEBlock(tf.keras.layers.Layer):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools the input, passes the result through
    Dense(input_channels // r) -> ReLU -> Dense(input_channels) -> sigmoid,
    and multiplies the input by the resulting per-channel weights.

    Args:
      input_channels: channel count of the input (last axis).
      r: reduction ratio of the hidden layer; at least one hidden unit is kept.
    """
    def __init__(self, input_channels, r=16):
        super(SEBlock, self).__init__()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        # Guard against a zero-width hidden layer when input_channels < r.
        self.fc1 = tf.keras.layers.Dense(units=max(1, input_channels // r))
        self.fc2 = tf.keras.layers.Dense(units=input_channels)

    def call(self, inputs, **kwargs):
        branch = self.pool(inputs)
        branch = tf.nn.relu(self.fc1(branch))
        branch = tf.nn.sigmoid(self.fc2(branch))
        # (batch, C) -> (batch, 1, 1, C) so the gate broadcasts over the two
        # spatial axes; plain multiplication avoids creating a new Multiply
        # layer on every call.
        branch = branch[:, tf.newaxis, tf.newaxis, :]
        return inputs * branch


class ChannelAttention(tf.keras.layers.Layer):
    """CBAM channel-attention gate.

    Steps:
      1. Compute global average- and max-pooled channel descriptors.
      2. Pass each through a shared bottleneck:
         1x1 conv (in_planes // ratio) -> ReLU -> 1x1 conv (in_planes).
      3. Sum the two results and apply a sigmoid.

    Returns a (batch, 1, 1, in_planes) gate for the caller to multiply with
    the feature map.

    Args:
      in_planes: channel count of the input and of the returned gate.
      ratio: reduction ratio of the hidden layer; at least one filter is kept.
    """
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg = tf.keras.layers.GlobalAveragePooling2D()
        self.max = tf.keras.layers.GlobalMaxPooling2D()
        self.conv1 = tf.keras.layers.Conv2D(filters=max(1, in_planes // ratio),
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

        self.conv2 = tf.keras.layers.Conv2D(filters=in_planes,
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def _shared_mlp(self, descriptor):
        """Apply the shared bottleneck; ReLU sits between the two convs."""
        return self.conv2(tf.nn.relu(self.conv1(descriptor)))

    def call(self, inputs):
        # (batch, C) -> (batch, 1, 1, C) so the 1x1 convolutions apply. Slicing
        # replaces the Reshape layers that were re-created on every call.
        avg_desc = self.avg(inputs)[:, tf.newaxis, tf.newaxis, :]
        max_desc = self.max(inputs)[:, tf.newaxis, tf.newaxis, :]
        out = self._shared_mlp(avg_desc) + self._shared_mlp(max_desc)
        return tf.nn.sigmoid(out)

class SpatialAttention(tf.keras.layers.Layer):
    """CBAM spatial-attention gate.

    Takes the channel-wise mean and max of the input (reduced over axis 3,
    so a rank-4 channels-last input is assumed) and stacks them into two
    channels. A single-filter `kernel_size` convolution follows, then a
    sigmoid. The result is a one-channel gate in (0, 1) for the caller to
    multiply with the feature map.

    Args:
      kernel_size: size of the convolution applied to the pooled maps.
    """
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=1,
                                            kernel_size=kernel_size,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def call(self, inputs):
        avg_out = tf.reduce_mean(inputs, axis=3)
        max_out = tf.reduce_max(inputs, axis=3)
        out = tf.stack([avg_out, max_out], axis=3)
        out = self.conv1(out)
        # Sigmoid (not ReLU) keeps the attention map bounded in (0, 1), as in
        # CBAM; ReLU produced an unbounded, non-gating scale.
        out = tf.nn.sigmoid(out)

        return out

class ConvBlockAttentionModule(tf.keras.layers.Layer):
    """Convolutional Block Attention Module: channel gate, then spatial gate.

    Args:
      out_channels: channel count of the feature map being refined.
      ratio: reduction ratio passed to ChannelAttention.
      kernel_size: convolution size passed to SpatialAttention.
    """
    def __init__(self, out_channels, ratio = 16, kernel_size = 7):
        super(ConvBlockAttentionModule, self).__init__()
        self.ca = ChannelAttention(in_planes=out_channels, ratio=ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def call(self, inputs, **kwargs):
        # Refine channels first, then re-weight spatial positions of the result.
        channel_refined = inputs * self.ca(inputs)
        return channel_refined * self.sa(channel_refined)
      

class Bottle_building_block(tf.keras.layers.Layer):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions, 4x expansion.

    The main path:
      - reduces to `filter_num` channels with a 1x1 conv;
      - applies a 3x3 convolution with `stride`;
      - expands to `filter_num * 4` channels with a 1x1 conv.
    Optionally an SE block and/or a CBAM block re-weights the expanded
    features before they are added to the shortcut.

    NOTE(review): the shortcut is always a 1x1 conv + BN projection, even for
    stride-1 blocks whose input already has `filter_num * 4` channels (the
    ResNet paper uses an identity shortcut there). Left unchanged because
    altering it changes the model's parameter set.

    Args:
      filter_num: bottleneck width; the block outputs `filter_num * 4` channels.
      stride: stride of the 3x3 convolution and of the projection shortcut.
      se_block: if truthy, apply a Squeeze-and-Excitation block.
      cbam_block: if truthy, apply a CBAM block (after SE when both are set).
    """
    def __init__(self, filter_num=None, stride=2, se_block=False, cbam_block=False):
        super(Bottle_building_block, self).__init__()
        # First layer of the block: 1x1 channel reduction.
        self.conv1 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(1, 1),
                                            strides=1,
                                            padding='valid')
        self.bn1 = tf.keras.layers.BatchNormalization()

        # Second layer of the block: 3x3 convolution carrying the stride.
        self.conv2 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(3, 3),
                                            strides=stride,
                                            padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        # Third layer of the block: 1x1 expansion to filter_num * 4 channels.
        self.conv3 = tf.keras.layers.Conv2D(filters=filter_num*4,
                                            kernel_size=(1, 1),
                                            strides=1,
                                            padding='valid')
        self.bn3 = tf.keras.layers.BatchNormalization()

        # Optional attention modules, only constructed when enabled.
        self.use_se = se_block
        self.se_block = SEBlock(input_channels=filter_num * 4) if se_block else None

        self.use_cbam = cbam_block
        self.cbam_block = (ConvBlockAttentionModule(out_channels=filter_num * 4)
                           if cbam_block else None)

        # Projection shortcut matching the main path's channels and stride.
        self.downsample = tf.keras.Sequential()
        self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num*4,
                                                   kernel_size=(1, 1),
                                                   strides=stride,
                                                   padding='valid'))
        self.downsample.add(tf.keras.layers.BatchNormalization())

    def call(self, inputs, training=None):
        x_shortcut = self.downsample(inputs)

        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv3(x)
        x = self.bn3(x, training=training)
        if self.use_se:
            x = self.se_block(x)
        if self.use_cbam:
            x = self.cbam_block(x)

        x = tf.nn.relu(tf.keras.layers.add([x_shortcut, x]))

        return x

def make_bottle_block_layer(filter_num, blocks, stride=2, use_se=False, use_cbam=False):
  """Stack `blocks` Bottle_building_block layers into one residual stage.

  The first block uses `stride`; every later block uses stride 1. All blocks
  share the same SE / CBAM settings.

  Returns:
    A tf.keras.Sequential containing the blocks.
  """
  stage = tf.keras.Sequential()
  for index in range(blocks):
    stage.add(Bottle_building_block(filter_num,
                                    stride=(1 if index else stride),
                                    se_block=use_se,
                                    cbam_block=use_cbam))
  return stage

class BottleResNet(tf.keras.Model):
  """ResNet-50/101/152 style network built from bottleneck blocks.

  Stem:
    - zero-pad 3, 7x7 stride-2 conv, BN, ReLU;
    - zero-pad 1, 3x3 stride-2 max pool.
  Body:
    - four bottleneck stages with 64/128/256/512 widths (stages 2-4 start
      with stride 2);
    - optionally with SE and/or CBAM attention in every block.
  Head:
    - global-average-pooling dense head.

  Args:
    layer: key into `layers_in_block` selecting the number of blocks per stage.
    se_block: enable SE attention in every block.
    cbam_block: enable CBAM attention in every block.
    classes: number of output units (sigmoid head for 1, softmax otherwise).
  """
  def __init__(self, layer='resnet-101', se_block=False, cbam_block=False, classes=1):
    super(BottleResNet, self).__init__()

    self.zero_padd_1 = tf.keras.layers.ZeroPadding2D(padding=(3, 3))
    self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(7, 7), strides=2)
    self.bn1 = tf.keras.layers.BatchNormalization()
    self.zero_padd_2 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))
    self.max_pooling = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2)

    depths = layers_in_block[layer]
    self.res_block1 = make_bottle_block_layer(64, depths[0], 1, use_se=se_block, use_cbam=cbam_block)
    self.res_block2 = make_bottle_block_layer(128, depths[1], 2, use_se=se_block, use_cbam=cbam_block)
    self.res_block3 = make_bottle_block_layer(256, depths[2], 2, use_se=se_block, use_cbam=cbam_block)
    self.res_block4 = make_bottle_block_layer(512, depths[3], 2, use_se=se_block, use_cbam=cbam_block)

    self.fc_block1 = make_dense_layer('avg', classes)

  def call(self, inputs, training=None, cam = None, feature_map = None):
    """Run the network.

    Return value:
      - cam == 'grad': the pair (last stage output, head output), presumably
        for Grad-CAM. This check takes priority over `feature_map`.
      - feature_map == 'fm': a 9-tuple (stem conv, stem BN, stem ReLU, stem
        pool, the four stage outputs, head output).
      - Otherwise: just the head output.
    """
    stem_conv = self.conv1(self.zero_padd_1(inputs))
    stem_bn = self.bn1(stem_conv, training=training)
    stem_relu = tf.nn.relu(stem_bn)
    stem_pool = self.max_pooling(self.zero_padd_2(stem_relu))

    stage_outputs = []
    features = stem_pool
    for stage in (self.res_block1, self.res_block2, self.res_block3, self.res_block4):
      features = stage(features)
      stage_outputs.append(features)

    head_out = self.fc_block1(features)

    if cam == 'grad':
      return features, head_out

    if feature_map == 'fm':
      return (stem_conv, stem_bn, stem_relu, stem_pool, *stage_outputs, head_out)

    return head_out

In [0]:
def resnet_18(classes):
  """Build a ResNet-18 (basic blocks, stage depths [2, 2, 2, 2])."""
  model = BasicResNet(classes=classes, layer='resnet-18')
  return model

def resnet_34(classes):
  """Build a ResNet-34 (basic blocks, stage depths [3, 4, 6, 3])."""
  model = BasicResNet(classes=classes, layer='resnet-34')
  return model

def resnet_50(classes):
  """Build a ResNet-50 (bottleneck blocks, stage depths [3, 4, 6, 3])."""
  model = BottleResNet(classes=classes, layer='resnet-50')
  return model

def resnet_101(classes):
  """Build a ResNet-101 (bottleneck blocks, stage depths [3, 4, 23, 3])."""
  model = BottleResNet(classes=classes, layer='resnet-101')
  return model

def resnet_152(classes):
  """Build a ResNet-152 (bottleneck blocks, stage depths [3, 8, 36, 3])."""
  model = BottleResNet(classes=classes, layer='resnet-152')
  return model

def se_resnet_50(classes):
  """Build a ResNet-50 with a Squeeze-and-Excitation block in every bottleneck."""
  model = BottleResNet(classes=classes, se_block=True, layer='resnet-50')
  return model

def se_resnet_101(classes):
  """Build a ResNet-101 with a Squeeze-and-Excitation block in every bottleneck."""
  model = BottleResNet(classes=classes, se_block=True, layer='resnet-101')
  return model

def se_resnet_152(classes):
  """Build a ResNet-152 with a Squeeze-and-Excitation block in every bottleneck."""
  model = BottleResNet(classes=classes, se_block=True, layer='resnet-152')
  return model

def cbam_resnet_50(classes):
  """Build a ResNet-50 with a CBAM attention block in every bottleneck."""
  model = BottleResNet(classes=classes, cbam_block=True, layer='resnet-50')
  return model

def cbam_resnet_101(classes):
  """Build a ResNet-101 with a CBAM attention block in every bottleneck."""
  model = BottleResNet(classes=classes, cbam_block=True, layer='resnet-101')
  return model

def cbam_resnet_152(classes):
  """Build a ResNet-152 with a CBAM attention block in every bottleneck."""
  model = BottleResNet(classes=classes, cbam_block=True, layer='resnet-152')
  return model