# Tensorflow ResNet


In [4]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from typing import List


In [5]:
original_resnet = tf.keras.applications.resnet.ResNet50()
original_resnet.summary()

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 230, 230, 3)  0           ['input_2[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 112, 112, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                           

In [6]:
def residual_block(tensor, name: str, filters: int, strides: tuple = (1, 1), is_conv_shortcut: bool = True):
    """A residual block.
    Args:
        tensor: input tensor
        name: string, block label.
        filters: integer, filters of the bottleneck layer.
        strides: default 1, stride of the first layer and shortcut layer.
        conv_shortcut: default True, use convolution shortcut if True, otherwise identity shortcut.
    Returns:
        Output tensor for the residual block.
    """

    shortcut = tensor
    if is_conv_shortcut:
        shortcut = tf.keras.layers.Conv2D(
            name=f"{name}x0_SC_Conv", filters=4*filters, kernel_size=(1, 1), strides=strides, padding='VALID')(tensor)
        shortcut = tf.keras.layers.BatchNormalization(
            name=f"{name}x0_SC_BN")(shortcut)

    # 1st Conv
    out = tf.keras.layers.Conv2D(name=f"{name}x1_Conv", filters=filters, kernel_size=(
        1, 1), strides=strides, padding='VALID')(tensor)
    out = tf.keras.layers.BatchNormalization(name=f"{name}x1_BN")(out)
    out = tf.keras.layers.ReLU(name=f"{name}x1_ReLU")(out)

    # 2nd Conv
    out = tf.keras.layers.Conv2D(name=f"{name}x2_Conv", filters=filters, kernel_size=(
        3, 3), strides=(1, 1), padding='SAME')(out)
    out = tf.keras.layers.BatchNormalization(name=f"{name}x2_BN")(out)
    out = tf.keras.layers.ReLU(name=f"{name}x2_ReLU")(out)

    # 3rd Conv
    out = tf.keras.layers.Conv2D(
        name=f"{name}x3_Conv", filters=4*filters, kernel_size=(1, 1), strides=(1, 1), padding='VALID')(out)
    out = tf.keras.layers.BatchNormalization(name=f"{name}x3_BN")(out)

    # Connection
    out = tf.keras.layers.Add(name=f"{name}x3_Add")([shortcut, out])
    out = tf.keras.layers.ReLU(name=f"{name}x3_ReLU")(out)

    return out


In [7]:
def residual_stack(tensor, filters: int, num_blocks: int, is_first_block: bool = False, name=None):
    """A set of stacked residual blocks.
    Args:
        tensor: input tensor
        filters: integer, filters of the bottleneck layer in a block.
        num_blocks: integer, blocks in the stacked blocks.
        name: string, stack label.
    Returns:
        Output tensor for the stacked blocks.
    """

    if is_first_block:
        out = residual_block(tensor=tensor, name=f"{name}_Block1", filters=filters)
    else:
        out = residual_block(
            tensor=tensor, name=f"{name}_Block1", filters=filters, strides=(2, 2))

    for i in range(2, num_blocks+1):
        out = residual_block(tensor=out, name=f"{name}_Block{i}", filters=filters, is_conv_shortcut=False)

    return out


In [8]:
def ResNet(name, blocks: List[int], num_classes: int):
    def stack_fn(tensor):
        out = tf.keras.layers.Conv2D(name="Head_Conv", filters=64, kernel_size=(
            7, 7), strides=(2, 2), padding="same")(tensor)
        out = tf.keras.layers.MaxPool2D(name="Head_MaxPool", pool_size=(
            3, 3), strides=(2, 2), padding="same")(out)
        out = residual_stack(tensor=out, name="Stack1",
                             is_first_block=True, filters=64, num_blocks=blocks[0])
        out = residual_stack(tensor=out, name="Stack2",
                             filters=128, num_blocks=blocks[1])
        out = residual_stack(tensor=out, name="Stack3",
                             filters=256, num_blocks=blocks[2])
        out = residual_stack(tensor=out, name="Stack4",
                             filters=512, num_blocks=blocks[3])
        out = tf.keras.layers.GlobalAveragePooling2D(
            name="Tail_GlobalAVGPool")(out)
        out = tf.keras.layers.Flatten(name="Tail_Flatten")(out)
        out = tf.keras.layers.Dense(
            name="Tail_Dense", units=num_classes, activation='softmax')(out)
        return out

    x = tf.keras.layers.Input(name="Head_Input", shape=(224, 224, 3))
    model = tf.keras.models.Model(name=name, inputs=[x], outputs=stack_fn(x))
    return model


In [9]:
resnet_50 = ResNet(name="ResNet101", blocks=[3, 4, 6, 3], num_classes=1000)
resnet_50.summary()


Model: "ResNet101"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Head_Input (InputLayer)        [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 Head_Conv (Conv2D)             (None, 112, 112, 64  9472        ['Head_Input[0][0]']             
                                )                                                                 
                                                                                                  
 Head_MaxPool (MaxPooling2D)    (None, 56, 56, 64)   0           ['Head_Conv[0][0]']              
                                                                                          