# The ResNet architecture: from the identity block to the whole structure

* The first part of building the ResNet architecture is the identity block, which will be the base for the whole structure.

In [1]:
#!/usr/bin/env python3
"""
Makes the ResNet model identity block.
"""


import tensorflow.keras as keras


def _conv_bn(inputs, n_filters, kernel_size):
    """
    Apply a 'same'-padded Conv2D followed by BatchNormalization.

    A fresh He normal initializer is created on every call so that no
    initializer instance is shared between layers (recent Keras versions
    warn that reusing one unseeded initializer instance across variables
    can yield identical values).

    :param inputs: tensor fed to the convolution.

    :param n_filters: number of output filters of the convolution.

    :param kernel_size: convolution window, e.g. (1, 1) or (3, 3).

    Returns: the batch-normalized convolution output (not activated).
    """
    conv = keras.layers.Conv2D(
        n_filters,
        kernel_size,
        padding='same',
        kernel_initializer=keras.initializers.he_normal(),
    )(inputs)
    return keras.layers.BatchNormalization(axis=3)(conv)


def identity_block(A_prev, filters):
    """
    Identity (bottleneck) block function.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU ->
    1x1 conv -> BN. The result is added to the unchanged input (the
    shortcut), and the sum is passed through a final ReLU.

    :param A_prev: output of the previous layer.

    :param filters: sequence of three filter counts (F11, F3, F12):
        F11 - number of filters of the first 1x1 convolution
        F3  - number of filters of the 3x3 convolution
        F12 - number of filters of the second 1x1 convolution; it must
              equal the channel count of A_prev so the shortcut addition
              is shape-compatible

    All convolutions use 'same' padding and He normal initialization.
    Batch normalization is applied along axis 3 (the channel axis for
    channels-last 4D input).

    Returns: Activated output of the block.
    """
    F11, F3, F12 = filters
    relu = keras.activations.relu

    X = _conv_bn(A_prev, F11, (1, 1))
    X = keras.layers.Activation(relu)(X)

    X = _conv_bn(X, F3, (3, 3))
    X = keras.layers.Activation(relu)(X)

    # No ReLU here: the activation is applied after the shortcut addition.
    X = _conv_bn(X, F12, (1, 1))

    X = keras.layers.Add()([X, A_prev])
    return keras.layers.Activation(relu)(X)


In [3]:
# Demo: build one identity block on a 224x224x256 input and print a summary.

if __name__ == '__main__':
    X = keras.Input(shape=(224, 224, 256))
    Y = identity_block(X, [64, 64, 256])
    model = keras.models.Model(X, Y)
    model.summary()