In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

In [0]:
class SEBlock(tf.keras.layers.Layer):
    """Squeeze-and-Excitation gate: rescales each channel of a feature map.

    The input is globally average-pooled, passed through a two-layer dense
    bottleneck (ReLU, then sigmoid) and the resulting per-channel weights are
    multiplied back onto the input.

    Args:
        input_channels: width of the second dense layer; must equal the
            channel count of the tensor given to ``call`` for the final
            multiplication to broadcast.
        r: reduction ratio of the bottleneck (first dense layer has
            ``input_channels // r`` units, but never fewer than 1).
    """

    def __init__(self, input_channels, r=16):
        super(SEBlock, self).__init__()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        # Clamp to 1 so that input_channels < r does not produce a
        # zero-width bottleneck.
        self.fc1 = tf.keras.layers.Dense(units=max(1, input_channels // r))
        self.fc2 = tf.keras.layers.Dense(units=input_channels)

    def call(self, inputs, **kwargs):
        """Return ``inputs`` scaled by the learned per-channel gate.

        NOTE(review): the two ``expand_dims`` calls below put the channel
        axis last, i.e. this assumes channels-last inputs -- confirm if
        channels-first support is ever needed.
        """
        gate = self.pool(inputs)
        gate = tf.nn.relu(self.fc1(gate))
        gate = tf.nn.sigmoid(self.fc2(gate))
        # (batch, C) -> (batch, 1, 1, C) so the gate broadcasts over H and W.
        gate = tf.expand_dims(tf.expand_dims(gate, axis=1), axis=1)
        # Plain broadcasting multiply; avoids instantiating a new Keras
        # Multiply layer on every forward pass.
        return inputs * gate

class ChannelAttention(tf.keras.layers.Layer):
    """CBAM channel-attention gate.

    Global average- and max-pooled descriptors of the input are each passed
    through the same two-layer 1x1-conv MLP (conv1 -> ReLU -> conv2); the two
    results are summed and squashed with a sigmoid.

    Args:
        in_planes: number of filters of the second 1x1 conv, i.e. the channel
            count of the returned gate.
        ratio: reduction ratio of the MLP bottleneck (first conv has
            ``in_planes // ratio`` filters, but never fewer than 1).
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg = tf.keras.layers.GlobalAveragePooling2D()
        self.max = tf.keras.layers.GlobalMaxPooling2D()
        # Clamp to 1 so that in_planes < ratio does not request a
        # zero-filter convolution.
        self.conv1 = tf.keras.layers.Conv2D(filters=max(1, in_planes // ratio),
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

        self.conv2 = tf.keras.layers.Conv2D(filters=in_planes,
                                            kernel_size=1,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def _shared_mlp(self, pooled):
        """Run one pooled (batch, C) descriptor through the shared MLP."""
        # (batch, C) -> (batch, 1, 1, C) so the 1x1 convs can consume it.
        pooled = tf.expand_dims(tf.expand_dims(pooled, axis=1), axis=1)
        # The non-linearity sits between the two layers; applying it after
        # conv2 instead would make the sigmoid below unable to go under 0.5.
        hidden = tf.nn.relu(self.conv1(pooled))
        return self.conv2(hidden)

    def call(self, inputs):
        """Return the (batch, 1, 1, in_planes) attention gate in (0, 1)."""
        avg_out = self._shared_mlp(self.avg(inputs))
        max_out = self._shared_mlp(self.max(inputs))
        return tf.nn.sigmoid(avg_out + max_out)

class SpatialAttention(tf.keras.layers.Layer):
    """CBAM spatial-attention gate.

    The input is reduced over axis 3 with both mean and max, the two maps are
    stacked on a new last axis and convolved down to a single map, which is
    squashed with a sigmoid.

    Args:
        kernel_size: size of the square convolution kernel applied to the
            stacked mean/max maps.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=1,
                                            kernel_size=kernel_size,
                                            kernel_initializer='he_normal',
                                            strides=1,
                                            padding='same')

    def call(self, inputs):
        """Return a single-channel spatial gate with values in (0, 1).

        NOTE(review): reducing over axis 3 treats the last axis of a rank-4
        input as channels, i.e. channels-last layout -- confirm.
        """
        avg_out = tf.reduce_mean(inputs, axis=3)
        max_out = tf.reduce_max(inputs, axis=3)
        pooled = tf.stack([avg_out, max_out], axis=3)
        # Sigmoid (not ReLU) so the map acts as a bounded gate rather than an
        # unbounded scale that can also hard-zero whole locations.
        return tf.nn.sigmoid(self.conv1(pooled))

class ConvBlockAttentionModule(tf.keras.layers.Layer):
    """Chains a ChannelAttention gate and a SpatialAttention gate.

    Args:
        out_channels: forwarded to ``ChannelAttention`` as ``in_planes``.
        ratio: forwarded to ``ChannelAttention`` as ``ratio``.
        kernel_size: forwarded to ``SpatialAttention`` as ``kernel_size``.
    """

    def __init__(self, out_channels, ratio = 16, kernel_size = 7):
        super(ConvBlockAttentionModule, self).__init__()
        self.ca = ChannelAttention(in_planes=out_channels, ratio=ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def call(self, inputs, **kwargs):
        """Scale ``inputs`` by the channel gate, then by the spatial gate."""
        channel_refined = self.ca(inputs) * inputs
        spatial_gate = self.sa(channel_refined)
        return spatial_gate * channel_refined

class ConvBlock(tf.keras.Model):
  """One bottleneck layer of a dense block.

  Pipeline: BN -> ReLU -> 1x1 conv (4 * num_filters) -> BN -> ReLU ->
  3x3 conv (num_filters) -> dropout, optionally followed by an SE block
  and/or a CBAM block.

  Args:
    num_filters: number of output channels (the growth rate).
    data_format: "channels_last", "channels_first" or None. None resolves to
      the Keras global default. Only used to pick the BatchNormalization
      axis; the conv layers are built without an explicit data_format.
    weight_decay: currently unused (no kernel regularizer is attached).
    dropout_rate: rate of the dropout applied after the 3x3 conv.
    use_se: if truthy, apply an SEBlock to the output.
    use_cbam: if truthy, apply a ConvBlockAttentionModule to the output.
  """

  def __init__(self, num_filters, data_format, weight_decay=1e-4,
               dropout_rate=0, use_se=False, use_cbam=False):
    super(ConvBlock, self).__init__()

    # None used to fall through to axis=1, normalising over the wrong axis
    # whenever the Keras default layout is channels_last.
    if data_format is None:
      data_format = tf.keras.backend.image_data_format()
    axis = -1 if data_format == "channels_last" else 1
    inter_filter = num_filters * 4

    # Attention sub-layers are only built when requested, so a small
    # num_filters (which would give them a zero-width bottleneck) cannot
    # break a plain, attention-free block.
    self.use_se = bool(use_se)
    self.se_block = (SEBlock(input_channels=num_filters)
                     if self.use_se else None)

    self.use_cbam = bool(use_cbam)
    self.cbam_block = (ConvBlockAttentionModule(out_channels=num_filters)
                       if self.use_cbam else None)

    # NOTE: use_bias=False, data_format and an L2 kernel_regularizer were
    # deliberately left disabled on both convolutions.
    self.conv1 = tf.keras.layers.Conv2D(inter_filter,
                                        (1, 1),
                                        padding="same",
                                        kernel_initializer="he_normal")
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

    self.conv2 = tf.keras.layers.Conv2D(num_filters,
                                        (3, 3),
                                        padding="same",
                                        kernel_initializer="he_normal")
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis)
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis)

  def call(self, inputs, training=True):
    """Run the bottleneck and return a tensor with ``num_filters`` channels."""
    x = self.batchnorm1(inputs, training=training)

    x = self.conv1(tf.nn.relu(x))
    x = self.batchnorm2(x, training=training)

    x = self.conv2(tf.nn.relu(x))
    x = self.dropout(x, training=training)

    if self.use_se:
      x = self.se_block(x)
    if self.use_cbam:
      x = self.cbam_block(x)

    return x


class TransitionBlock(tf.keras.Model):
  """Transition between dense blocks: BN -> ReLU -> 1x1 conv -> avg-pool.

  Args:
    num_filters: number of output channels of the 1x1 convolution.
    data_format: "channels_last", "channels_first" or None. None resolves to
      the Keras global default. Used for the BatchNormalization axis and the
      pooling layer; the conv layer is built without an explicit data_format.
    weight_decay: currently unused (no kernel regularizer is attached).
    dropout_rate: currently unused.
  """

  def __init__(self, num_filters, data_format,
               weight_decay=1e-4, dropout_rate=0):
    super(TransitionBlock, self).__init__()
    # None used to fall through to axis=1, normalising over the wrong axis
    # whenever the Keras default layout is channels_last.
    if data_format is None:
      data_format = tf.keras.backend.image_data_format()
    axis = -1 if data_format == "channels_last" else 1

    self.batchnorm = tf.keras.layers.BatchNormalization(axis=axis)
    # NOTE: use_bias=False, data_format and an L2 kernel_regularizer were
    # deliberately left disabled on this convolution.
    self.conv = tf.keras.layers.Conv2D(num_filters,
                                       (1, 1),
                                       padding="same",
                                       kernel_initializer="he_normal")
    self.avg_pool = tf.keras.layers.AveragePooling2D(data_format=data_format)

  def call(self, inputs, training=True):
    """Normalise, activate, project to ``num_filters`` channels and pool."""
    x = self.batchnorm(inputs, training=training)
    x = self.conv(tf.nn.relu(x))
    return self.avg_pool(x)


class DenseBlock(tf.keras.Model):
  """Stack of ConvBlocks with dense (concatenative) connectivity.

  Every ConvBlock receives the concatenation of the block input and the
  outputs of all preceding ConvBlocks, so the output has
  ``C_in + num_layers * growth_rate`` channels.

  Args:
    num_layers: number of ConvBlocks in the stack.
    growth_rate: channels added by each ConvBlock.
    data_format: "channels_last", "channels_first" or None. None resolves to
      the Keras global default. Selects the concatenation axis and is
      forwarded to every ConvBlock.
    weight_decay: forwarded to every ConvBlock.
    dropout_rate: forwarded to every ConvBlock.
    use_se: forwarded to every ConvBlock.
    use_cbam: forwarded to every ConvBlock.
  """

  def __init__(self, num_layers, growth_rate, data_format,
               weight_decay=1e-4, dropout_rate=0, use_se=False, use_cbam=False):
    super(DenseBlock, self).__init__()
    self.num_layers = num_layers
    # None used to fall through to axis=1, i.e. concatenation along a
    # spatial axis whenever the Keras default layout is channels_last.
    if data_format is None:
      data_format = tf.keras.backend.image_data_format()
    self.axis = -1 if data_format == "channels_last" else 1

    self.blocks = []
    for _ in range(int(self.num_layers)):
      self.blocks.append(ConvBlock(growth_rate,
                                   data_format,
                                   weight_decay,
                                   dropout_rate,
                                   use_se,
                                   use_cbam))

  def call(self, inputs, training=True):
    """Return ``inputs`` concatenated with the output of every ConvBlock."""
    # Start from the input so that a zero-layer block is the identity.
    x = inputs
    for block in self.blocks:
      new_features = block(x, training=training)
      # Accumulate onto the running feature map (not onto the original
      # input), so later layers see the outputs of all earlier ones.
      x = tf.concat([x, new_features], axis=self.axis)

    return x


class DenseNet(tf.keras.Model):
  """DenseNet backbone/classifier with optional SE or CBAM attention.

  Layout: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool ->
  [DenseBlock -> TransitionBlock] * (num_of_blocks - 1) -> DenseBlock ->
  BN -> ReLU -> (global average pool -> dense classifier, if include_top).

  Args:
    growth_rate: channels added by each ConvBlock; the stem conv uses
      ``2 * growth_rate`` filters.
    output_classes: number of units of the final dense layer.
    num_of_blocks: number of dense blocks. If None, it is taken from
      ``len(num_layers_in_each_block)``.
    num_layers_in_each_block: sequence giving the number of ConvBlocks in
      each dense block. Required.
    data_format: "channels_last", "channels_first" or None. None resolves to
      the Keras global default.
    compression: factor applied to the channel count at every transition.
    weight_decay: forwarded to the sub-blocks (currently unused there).
    dropout_rate: forwarded to the sub-blocks.
    include_top: whether to append global pooling and the classifier.
    use_se: enable SE attention in every ConvBlock.
    use_cbam: enable CBAM attention in every ConvBlock.

  Raises:
    TypeError: if ``num_layers_in_each_block`` is None.
  """

  def __init__(self, growth_rate, output_classes, num_of_blocks=None,
               num_layers_in_each_block=None, data_format="channels_last",
               compression=0.5, weight_decay=1e-4, dropout_rate=0.,
               include_top=True, use_se=False, use_cbam=False):
    super(DenseNet, self).__init__()

    if num_layers_in_each_block is None:
      raise TypeError(
          "num_layers_in_each_block is required (a sequence of ints), got None")
    if num_of_blocks is None:
      num_of_blocks = len(num_layers_in_each_block)
    # None used to fall through to axis=1 below (and in every sub-block),
    # which is wrong whenever the Keras default layout is channels_last.
    if data_format is None:
      data_format = tf.keras.backend.image_data_format()

    self.growth_rate = growth_rate
    self.num_of_blocks = num_of_blocks
    self.output_classes = output_classes
    self.num_layers_in_each_block = num_layers_in_each_block
    self.data_format = data_format
    self.compression = compression
    self.weight_decay = weight_decay
    self.dropout_rate = dropout_rate
    self.include_top = include_top
    self.use_se = use_se
    self.use_cbam = use_cbam

    axis = -1 if self.data_format == "channels_last" else 1

    self.num_filters = 2 * self.growth_rate

    # Stem: first conv and pool layer.
    # NOTE: use_bias=False, data_format and an L2 kernel_regularizer were
    # deliberately left disabled on this convolution.
    self.conv1 = tf.keras.layers.Conv2D(self.num_filters,
                                        (7, 7),
                                        strides=(2, 2),
                                        padding="same",
                                        kernel_initializer="he_normal")

    self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(3, 3),
                                              strides=(2, 2),
                                              padding="same",
                                              data_format=self.data_format)
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis)

    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis)

    # Channel count entering each dense block: previous count plus what the
    # previous dense block adds, shrunk by `compression` at the transition.
    num_filters_after_each_block = [self.num_filters]
    for i in range(1, self.num_of_blocks):
      grown = num_filters_after_each_block[i-1] + (
          self.growth_rate * self.num_layers_in_each_block[i-1])
      num_filters_after_each_block.append(int(grown * self.compression))

    # Dense blocks, with a transition block after all but the last one.
    self.dense_blocks = []
    self.transition_blocks = []
    for i in range(self.num_of_blocks):
      self.dense_blocks.append(DenseBlock(self.num_layers_in_each_block[i],
                                          self.growth_rate,
                                          self.data_format,
                                          self.weight_decay,
                                          self.dropout_rate,
                                          self.use_se,
                                          self.use_cbam))
      if i+1 < self.num_of_blocks:
        self.transition_blocks.append(
            TransitionBlock(num_filters_after_each_block[i+1],
                            self.data_format,
                            self.weight_decay,
                            self.dropout_rate))

    # Last pooling and fc layer.
    if self.include_top:
      self.last_pool = tf.keras.layers.GlobalAveragePooling2D(
          data_format=self.data_format)
      self.classifier = tf.keras.layers.Dense(self.output_classes)

  def call(self, inputs, training=None, cam=None):
    """Forward pass.

    Args:
      inputs: image batch.
      training: forwarded to every BatchNormalization / sub-block.
      cam: if equal to 'grad', also return the final activated feature map.

    Returns:
      The classifier output when ``include_top`` is set, otherwise the final
      activated feature map. With ``cam == 'grad'`` a tuple
      ``(feature_map, output)`` is returned instead.
    """
    x = self.conv1(inputs)
    x = self.batchnorm1(x, training=training)
    x = tf.nn.relu(x)
    x = self.pool1(x)

    for i in range(self.num_of_blocks - 1):
      x = self.dense_blocks[i](x, training=training)
      x = self.transition_blocks[i](x, training=training)

    x = self.dense_blocks[self.num_of_blocks - 1](x, training=training)
    x = self.batchnorm2(x, training=training)
    x_cam = tf.nn.relu(x)

    # Without a top, hand back the activated features; previously the ReLU
    # result was dropped and the raw BatchNorm output was returned.
    x = x_cam
    if self.include_top:
      x = self.last_pool(x_cam)
      x = self.classifier(x)

    if cam == 'grad':
      return x_cam, x

    return x

In [0]:
def densenet_121(output_classes):
  """Build a DenseNet with block depths [6, 12, 24, 16] and dropout 0.3."""
  config = dict(
      growth_rate=32,
      num_of_blocks=4,
      num_layers_in_each_block=[6, 12, 24, 16],
      dropout_rate=0.3,
  )
  return DenseNet(output_classes=output_classes, **config)

def densenet_169(output_classes):
  """Build a DenseNet with block depths [6, 12, 32, 32] and dropout 0.4."""
  config = dict(
      growth_rate=32,
      num_of_blocks=4,
      num_layers_in_each_block=[6, 12, 32, 32],
      data_format=None,
      dropout_rate=0.4,
  )
  return DenseNet(output_classes=output_classes, **config)

def densenet_201(output_classes):
  """Build a DenseNet with block depths [6, 12, 48, 32] and dropout 0.5."""
  config = dict(
      growth_rate=32,
      num_of_blocks=4,
      num_layers_in_each_block=[6, 12, 48, 32],
      data_format=None,
      dropout_rate=0.5,
  )
  return DenseNet(output_classes=output_classes, **config)

def densenet_265(output_classes):
  """Build a DenseNet with block depths [6, 12, 64, 48] and dropout 0.5."""
  config = dict(
      growth_rate=32,
      num_of_blocks=4,
      num_layers_in_each_block=[6, 12, 64, 48],
      data_format=None,
      dropout_rate=0.5,
  )
  return DenseNet(output_classes=output_classes, **config)

In [0]:
def se_densenet_121(output_classes):
  """Build a DenseNet with block depths [6, 12, 24, 16] and SE attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 24, 16],
                  data_format=None, use_se=True)

def se_densenet_169(output_classes):
  """Build a DenseNet with block depths [6, 12, 32, 32] and SE attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 32, 32],
                  data_format=None, use_se=True)

def se_densenet_201(output_classes):
  """Build a DenseNet with block depths [6, 12, 48, 32] and SE attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 48, 32],
                  data_format=None, use_se=True)

def se_densenet_265(output_classes):
  """Build a DenseNet with block depths [6, 12, 64, 48] and SE attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 64, 48],
                  data_format=None, use_se=True)

In [0]:
def cbam_densenet_121(output_classes):
  """Build a DenseNet with block depths [6, 12, 24, 16] and CBAM attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 24, 16],
                  data_format=None, use_cbam=True)

def cbam_densenet_169(output_classes):
  """Build a DenseNet with block depths [6, 12, 32, 32] and CBAM attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 32, 32],
                  data_format=None, use_cbam=True)

def cbam_densenet_201(output_classes):
  """Build a DenseNet with block depths [6, 12, 48, 32] and CBAM attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 48, 32],
                  data_format=None, use_cbam=True)

def cbam_densenet_265(output_classes):
  """Build a DenseNet with block depths [6, 12, 64, 48] and CBAM attention.

  Returns the constructed model (the result was previously discarded, so the
  function returned None).
  """
  return DenseNet(growth_rate=32, output_classes=output_classes,
                  num_of_blocks=4,
                  num_layers_in_each_block=[6, 12, 64, 48],
                  data_format=None, use_cbam=True)