In [9]:
import tensorflow as tf


def get_res_model(input_shape=(256, 256, 3), channels=(32, 32)) -> tf.keras.Model:
    """Build a small residual CNN that ends in a 2-unit sigmoid head.

    Each entry in ``channels`` adds one residual module: two 3x3 "same"
    ReLU convolutions whose result is summed with the module's input. When a
    module is wider than its input, the input is zero-padded along the last
    axis so both operands of the sum have the same number of channels.

    Args:
        input_shape: Shape of one input sample, passed to ``tf.keras.Input``.
            The padding spec below has four entries, so the batched tensor
            must be rank 4 (the code treats the last axis as channels).
        channels: Output width of each residual module, in order. A module
            may not be narrower than the tensor feeding it, because the
            shortcut can only be padded, not cropped.

    Returns:
        An uncompiled ``tf.keras.Model`` named "Residual Model".

    Raises:
        ValueError: If ``channels`` is empty, or if a module is narrower than
            its input.
    """
    channels = list(channels)
    if not channels:
        raise ValueError("channels must contain at least one entry")

    model_input = tf.keras.Input(input_shape)

    features = model_input
    for channel in channels:
        # Read the incoming width from the tensor itself instead of tracking
        # it in a separate counter that can drift from the real input shape.
        in_channels = features.shape[-1]
        extra_channels = channel - in_channels
        if extra_channels < 0:
            raise ValueError(
                f"residual module width {channel} is smaller than its input "
                f"width {in_channels}; the shortcut can only be zero-padded"
            )

        x = tf.keras.layers.Conv2D(channel, 3, padding="same", activation="relu")(features)
        x = tf.keras.layers.Conv2D(channel, 3, padding="same", activation="relu")(x)

        # Zero-pad only the trailing (channel) axis of the shortcut.
        # NOTE(review): tf.pad is applied directly to a Keras symbolic tensor;
        # this presumably relies on Keras wrapping raw TF ops automatically —
        # confirm on the target TF/Keras version.
        shortcut = tf.pad(features, [[0, 0], [0, 0], [0, 0], [0, extra_channels]])
        features = tf.keras.layers.Add()([x, shortcut])

    x = tf.keras.layers.Flatten()(features)
    model_output = tf.keras.layers.Dense(2, activation="sigmoid")(x)

    return tf.keras.Model(inputs=model_input, outputs=model_output, name="Residual Model")

# Instantiate the residual model and print its layer-by-layer summary.
res_model = get_res_model()
res_model.summary()

Model: "Residual Model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_8 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_32 (Conv2D)             (None, 256, 256, 32  896         ['input_8[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_33 (Conv2D)             (None, 256, 256, 32  9248        ['conv2d_32[0][0]']              
                                )                                                    

In [10]:
def get_xception_model(input_shape=(256, 256, 3), channels=(64, 128, 256)) -> "tf.keras.Model":
    """Build a small Xception-style CNN that ends in a 2-unit sigmoid head.

    Each entry in ``channels`` adds one downsampling block:

    * main path: BatchNorm -> ReLU -> 3x3 SeparableConv -> BatchNorm -> ReLU
      -> MaxPool -> 1x1 Conv (both convolutions without bias);
    * shortcut: a 1x1 Conv with stride 2 applied to the block input.

    The two paths are summed and the sum feeds the next block.

    Args:
        input_shape: Shape of one input sample, passed to ``tf.keras.Input``.
        channels: Output width of each block, in order.

    Returns:
        An uncompiled ``tf.keras.Model`` named "Xception Model".

    Raises:
        ValueError: If ``channels`` is empty.
    """
    channels = list(channels)
    if not channels:
        raise ValueError("channels must contain at least one entry")

    model_input = tf.keras.Input(input_shape)

    block_input = model_input
    for channel in channels:
        x = tf.keras.layers.BatchNormalization()(block_input)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.SeparableConv2D(channel, 3, padding="same", use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        # "same" pooling keeps the main path's spatial size equal to the
        # stride-2 "same" shortcut even when a dimension is odd; for even
        # sizes (such as the default 256) it matches the default pooling.
        x = tf.keras.layers.MaxPool2D(padding="same")(x)
        x = tf.keras.layers.Conv2D(channel, 1, padding="same", use_bias=False)(x)

        # Projection shortcut: matches the main path in width and stride.
        shortcut = tf.keras.layers.Conv2D(channel, 1, strides=2, padding="same")(block_input)

        block_input = tf.keras.layers.Add()([x, shortcut])

    x = tf.keras.layers.Flatten()(block_input)
    model_output = tf.keras.layers.Dense(2, activation="sigmoid")(x)
    return tf.keras.Model(inputs=model_input, outputs=model_output, name="Xception Model")

# Instantiate the Xception-style model and print its layer-by-layer summary.
xception_model = get_xception_model()
xception_model.summary()


Model: "Xception Model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_9 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 batch_normalization_16 (BatchN  (None, 256, 256, 3)  12         ['input_9[0][0]']                
 ormalization)                                                                                    
                                                                                                  
 re_lu_16 (ReLU)                (None, 256, 256, 3)  0           ['batch_normalization_16[0][0]'] 
                                                                                     