In [None]:
import tensorflow as tf
import numpy as np
from tensorflow import keras

In [None]:
def build_resnet_block(
    input_layer, block_num, cnn, channel, strides, not_plain=True, is_50=False
):
    """Stack ``cnn`` units that together form one ResNet stage.

    Each unit is either a basic unit (3x3 conv -> 3x3 conv) or, when ``is_50``
    is True, a bottleneck unit (1x1 -> 3x3 -> 1x1 conv, expanding to
    ``channel * 4`` filters). Only the first unit of the stage applies
    ``strides``; every following unit uses stride 1.

    Args:
        input_layer: Keras tensor feeding this stage.
        block_num: Stage index; only used to build unique layer names.
        cnn: Number of units stacked in this stage.
        channel: Base filter count (bottleneck units output ``channel * 4``).
        strides: Stride of the first unit of the stage.
        not_plain: If True, add shortcut connections (residual network);
            if False, build the plain network of the same depth.
        is_50: If True, use bottleneck units; otherwise basic units.

    Returns:
        The Keras tensor produced by the last unit of the stage.
    """
    # Number of channels every unit of this stage outputs.
    out_channels = channel * 4 if is_50 else channel
    x = input_layer

    for num_cnn in range(cnn):
        # Only the first unit downsamples. This must not depend on `not_plain`,
        # otherwise the plain network strides in every unit.
        unit_strides = strides if num_cnn == 0 else 1
        # The shortcut branches off THIS unit's input, not the stage input.
        shortcut = x

        # layer 1: 1x1 reduction (bottleneck) or 3x3 conv (basic); carries the stride.
        # NOTE: layer names keep the original "blcok" spelling.
        x = keras.layers.Conv2D(
            filters=channel,
            kernel_size=(1, 1) if is_50 else (3, 3),
            strides=unit_strides,
            kernel_initializer="he_normal",
            padding="same",
            name=f"blcok{block_num}_conv{num_cnn}_1",
        )(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation("relu")(x)

        # layer 2: 3x3 conv
        x = keras.layers.Conv2D(
            filters=channel,
            kernel_size=(3, 3),
            strides=1,
            kernel_initializer="he_normal",
            padding="same",
            name=f"blcok{block_num}_conv{num_cnn}_2",
        )(x)
        x = keras.layers.BatchNormalization()(x)

        # layer 3 (bottleneck only): 1x1 expansion to channel * 4
        if is_50:
            x = keras.layers.Activation("relu")(x)
            x = keras.layers.Conv2D(
                filters=out_channels,
                kernel_size=(1, 1),
                strides=1,
                kernel_initializer="he_normal",
                padding="same",
                name=f"blcok{block_num}_conv{num_cnn}_3",
            )(x)
            x = keras.layers.BatchNormalization()(x)

        if not_plain:
            # Use a 1x1 projection (+ BN) only when the spatial size or channel
            # count changes; otherwise keep a true identity shortcut.
            if unit_strides != 1 or shortcut.shape[-1] != out_channels:
                shortcut = keras.layers.Conv2D(
                    filters=out_channels,
                    kernel_size=(1, 1),
                    strides=unit_strides,
                    kernel_initializer="he_normal",
                    padding="same",
                    name=f"blcok{block_num}_identity_conv{num_cnn}",
                )(shortcut)
                shortcut = keras.layers.BatchNormalization()(shortcut)
            x = keras.layers.Add()([x, shortcut])

        x = keras.layers.Activation("relu")(x)

    return x

In [None]:
def build_resnet(input_shape, is_50=False, not_plain=True):
    """Build a ResNet-34/50 (or its plain counterpart) with one sigmoid output.

    Args:
        input_shape: Shape of a single input image, e.g. ``(224, 224, 3)``.
            Inputs are rescaled by 1/255 inside the model (presumably raw
            0-255 pixel values are expected).
        is_50: If True, build the bottleneck variant (ResNet-50); otherwise
            the basic-unit variant (ResNet-34).
        not_plain: If True, use shortcut connections; if False, build the
            plain network of the same depth.

    Returns:
        An uncompiled ``keras.Model`` mapping images to a single sigmoid output.
    """
    input_layer = keras.layers.Input(shape=input_shape)

    # Units per stage and base filters per stage; [3, 4, 6, 3] is shared by
    # ResNet-34 and ResNet-50.
    setting = {"cnn_list": [3, 4, 6, 3], "channel_list": [64, 128, 256, 512]}

    # Stem: rescale -> 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
    resnet_output = keras.layers.Rescaling(1.0 / 255)(input_layer)
    resnet_output = keras.layers.Conv2D(
        64, (7, 7), strides=2, padding="same", kernel_initializer="he_normal"
    )(resnet_output)
    resnet_output = keras.layers.BatchNormalization()(resnet_output)
    resnet_output = keras.layers.Activation("relu")(resnet_output)
    # padding="same" yields 56x56 for a 224x224 input, as in the ResNet paper
    # ("valid" padding gave 55x55).
    resnet_output = keras.layers.MaxPooling2D((3, 3), strides=2, padding="same")(
        resnet_output
    )

    for i, (cnn, channel) in enumerate(
        zip(setting["cnn_list"], setting["channel_list"])
    ):
        # Stage 0 keeps resolution (the max-pool already downsampled);
        # every later stage halves it in its first unit.
        strides = 1 if i == 0 else 2
        resnet_output = build_resnet_block(
            resnet_output,
            block_num=i,
            cnn=cnn,
            channel=channel,
            strides=strides,
            not_plain=not_plain,
            is_50=is_50,
        )

    # Global average pooling reduces the whole final feature map to one vector
    # per image for any input size. A 2x2 AveragePooling2D + Flatten drops the
    # last row/column of a 7x7 map and fails on maps smaller than 2x2.
    resnet_output = keras.layers.GlobalAveragePooling2D()(resnet_output)
    resnet_output = keras.layers.Dense(1, activation="sigmoid")(resnet_output)

    return keras.Model(inputs=input_layer, outputs=resnet_output)

In [None]:
# Default configuration: basic units with shortcut connections (ResNet-34)
# on 224x224 RGB images; display its layer-by-layer summary.
resnet34 = build_resnet(input_shape=(224, 224, 3), is_50=False, not_plain=True)
resnet34.summary()