In [1]:
import tensorflow as tf

In [2]:
class SeqConvBlock(tf.keras.layers.Layer):
    """Pre-activated separable convolution unit: ReLU -> SeparableConv2D -> BatchNorm.

    Args:
        filter_num: Number of output filters of the separable convolution.
        k: Kernel size forwarded to ``SeparableConv2D``.
        s: Stride forwarded to ``SeparableConv2D``.
        p: Padding mode forwarded to ``SeparableConv2D``.
    """

    def __init__(self, filter_num, k=1, s=1, p='valid'):
        super().__init__()
        # Sub-layers are created in this order so auto-generated layer names
        # stay stable.
        self.seqconv = tf.keras.layers.SeparableConv2D(
            filters=filter_num, kernel_size=k, strides=s, padding=p)
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None, **kwargs):
        """Apply ReLU, the separable convolution and batch normalisation.

        Args:
            inputs: Input tensor.
            training: Forwarded to the batch-normalisation layer.

        Returns:
            The batch-normalised output of the separable convolution.
        """
        activated = tf.nn.relu(inputs)
        convolved = self.seqconv(activated)
        return self.bn(convolved, training=training)

In [3]:
class Block(tf.keras.layers.Layer):
    """Residual unit: three stacked 3x3 ``SeqConvBlock`` layers plus an identity shortcut.

    Args:
        filter_num: Number of filters used by each of the three separable
            convolutions. The block input is added unchanged to the output of
            the last convolution, so its shape must be compatible with that
            output.
    """

    def __init__(self, filter_num):
        super(Block, self).__init__()
        self.seqconv1 = SeqConvBlock(filter_num=filter_num, k=3, s=1, p='same')
        self.seqconv2 = SeqConvBlock(filter_num=filter_num, k=3, s=1, p='same')
        self.seqconv3 = SeqConvBlock(filter_num=filter_num, k=3, s=1, p='same')

    def call(self, inputs, training=None, **kwargs):
        """Run the three convolution units and add the shortcut.

        Args:
            inputs: Input tensor; also used as the identity shortcut.
            training: Forwarded to every ``SeqConvBlock`` so their batch
                normalisation layers receive the mode explicitly instead of
                relying on implicit propagation from the caller.

        Returns:
            Element-wise sum of ``inputs`` and the output of the third unit.
        """
        x = self.seqconv1(inputs, training=training)
        x = self.seqconv2(x, training=training)
        x = self.seqconv3(x, training=training)

        return tf.keras.layers.add([inputs, x])

In [4]:
class DownBlock(tf.keras.layers.Layer):
    """Downsampling residual unit.

    Main branch: two 3x3 ``SeqConvBlock`` layers followed by a 3x3 max-pool
    with stride 2. Shortcut branch: a 1x1 convolution with stride 2. The two
    branches are summed element-wise.

    Args:
        filter_num: Number of filters for both separable convolutions and for
            the 1x1 shortcut convolution.
    """

    def __init__(self, filter_num):
        super(DownBlock, self).__init__()
        self.seqconv1 = SeqConvBlock(filter_num=filter_num, k=3, s=1, p='same')
        self.seqconv2 = SeqConvBlock(filter_num=filter_num, k=3, s=1, p='same')
        self.maxpool = tf.keras.layers.MaxPool2D(pool_size=(3, 3),
                                                 strides=2,
                                                 padding='same')
        # Strided 1x1 projection so the shortcut matches the pooled main branch.
        self.skip = tf.keras.layers.Conv2D(filters=filter_num,
                                           kernel_size=1,
                                           strides=2)

    def call(self, inputs, training=None, **kwargs):
        """Compute both branches and return their sum.

        Args:
            inputs: Input tensor fed to both the main and the shortcut branch.
            training: Forwarded to both ``SeqConvBlock`` layers so their batch
                normalisation layers receive the mode explicitly instead of
                relying on implicit propagation from the caller.

        Returns:
            Element-wise sum of the shortcut projection and the pooled main
            branch.
        """
        shortcut = self.skip(inputs)

        x = self.seqconv1(inputs, training=training)
        x = self.seqconv2(x, training=training)
        x = self.maxpool(x)

        return tf.keras.layers.add([shortcut, x])

In [9]:
class Xception(tf.keras.Model):
    """Xception-style image classifier built from separable-convolution blocks.

    Layout, in call order:
      * stem: two 3x3 convolutions (32 and 64 filters), each batch-normalised;
      * three ``DownBlock`` stages (128, 256, 728 filters);
      * eight ``Block`` residual units at 728 filters;
      * a final downsampling residual stage (728 -> 1024 filters);
      * two separable convolutions (1536, 2048 filters) with BN + ReLU;
      * global average pooling and a softmax ``Dense`` classifier.

    Args:
        num_classes: Number of units of the final softmax layer.
    """

    def __init__(self, num_classes=10):
        super(Xception, self).__init__()
        # Stem.
        self.conv1 = tf.keras.layers.Conv2D(filters=32,
                                            kernel_size=3,
                                            strides=2)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters=64,
                                            kernel_size=3,
                                            padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        # Downsampling residual stages.
        self.block1 = DownBlock(128)
        self.block2 = DownBlock(256)
        self.block3 = DownBlock(728)

        # Eight identity-shortcut residual units.
        self.block4 = Block(728)
        self.block5 = Block(728)
        self.block6 = Block(728)
        self.block7 = Block(728)
        self.block8 = Block(728)
        self.block9 = Block(728)
        self.block10 = Block(728)
        self.block11 = Block(728)

        # Final downsampling residual stage (main branch + strided 1x1 shortcut).
        self.seqconv12 = SeqConvBlock(filter_num=728, k=3, s=1, p='same')
        self.seqconv13 = SeqConvBlock(filter_num=1024, k=3, s=1, p='same')
        self.maxpool14 = tf.keras.layers.MaxPool2D(pool_size=(3, 3),
                                                   strides=2,
                                                   padding='same')
        self.skip15 = tf.keras.layers.Conv2D(filters=1024,
                                             kernel_size=1,
                                             strides=2)

        # Head convolutions.
        self.conv16 = tf.keras.layers.SeparableConv2D(filters=1536,
                                                      kernel_size=3,
                                                      strides=1,
                                                      padding='same')
        self.bn16 = tf.keras.layers.BatchNormalization()

        self.conv17 = tf.keras.layers.SeparableConv2D(filters=2048,
                                                      kernel_size=3,
                                                      strides=1,
                                                      padding='same')
        self.bn17 = tf.keras.layers.BatchNormalization()

        # Classifier.
        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(
            units=num_classes, activation=tf.keras.activations.softmax)

    def call(self, inputs, training=None, mask=None):
        """Run the forward pass.

        Args:
            inputs: Batch of input images.
            training: Forwarded explicitly to every batch-normalisation layer
                and to every custom block, so all of them see the same mode.
            mask: Unused; accepted for signature compatibility.

        Returns:
            Softmax class probabilities with ``num_classes`` entries per sample.
        """
        # Stem.
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = self.bn2(x, training=training)
        # No ReLU here: block1's SeqConvBlock applies ReLU to its own input,
        # while block1's shortcut convolution sees the un-activated tensor.
        # NOTE(review): confirm this is intended for the shortcut branch.

        # Downsampling stages followed by the identity residual units.
        stages = (
            self.block1, self.block2, self.block3,
            self.block4, self.block5, self.block6, self.block7,
            self.block8, self.block9, self.block10, self.block11,
        )
        for stage in stages:
            x = stage(x, training=training)

        # Final downsampling residual stage.
        main = self.seqconv12(x, training=training)
        main = self.seqconv13(main, training=training)
        main = self.maxpool14(main)
        shortcut = self.skip15(x)
        x = tf.keras.layers.add([main, shortcut])

        # Head convolutions.
        x = self.conv16(x)
        x = self.bn16(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv17(x)
        x = self.bn17(x, training=training)
        x = tf.nn.relu(x)

        # Classifier.
        x = self.avgpool(x)

        return self.fc(x)

In [10]:
def get_model(num_classes=10, input_shape=(None, 299, 299, 3)):
    """Create an ``Xception`` model, build it and print its layer summary.

    Args:
        num_classes: Number of output classes passed to ``Xception``.
        input_shape: Shape passed to ``model.build``. The default is a batch
            of 299x299 three-channel inputs with an unspecified batch size.

    Returns:
        The built ``Xception`` model.
    """
    model = Xception(num_classes=num_classes)
    # Building creates the weights so that summary() can report parameter counts.
    model.build(input_shape=input_shape)
    model.summary()
    return model

In [11]:
# Build the default network; get_model() prints its layer summary as a side effect.
model = get_model()

Model: "xception_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_6 (Conv2D)            multiple                  896       
_________________________________________________________________
batch_normalization_36 (Batc multiple                  128       
_________________________________________________________________
conv2d_7 (Conv2D)            multiple                  18496     
_________________________________________________________________
batch_normalization_37 (Batc multiple                  256       
_________________________________________________________________
down_block_3 (DownBlock)     multiple                  35904     
_________________________________________________________________
down_block_4 (DownBlock)     multiple                  137344    
_________________________________________________________________
down_block_5 (DownBlock)     multiple                  9