# In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Conv2D,
    BatchNormalization,
    Activation,
    MaxPooling2D,
    Add,
    AveragePooling2D,
    GlobalAveragePooling2D,
    Flatten,
    Dense,
    Input,
)
from tensorflow.keras.models import Model


def _layer_name(prefix, suffix):
    """Return ``prefix + suffix``, or ``None`` when ``prefix`` is ``None``.

    Keras generates a unique layer name when given ``name=None``, so this lets
    ``residual_block`` be called without an explicit ``name``.
    """
    return None if prefix is None else prefix + suffix


def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):
    """Apply a two-convolution ("basic") residual block to ``x``.

    Main path: conv(kernel_size, stride) -> BN -> ReLU -> conv(kernel_size) -> BN.
    The main path is added to the shortcut and passed through a final ReLU.

    Args:
        x: Input tensor in 'channels_last' layout (BatchNormalization uses axis 3).
        filters: Number of output channels of every convolution in the block.
        kernel_size: Spatial size of the two main-path convolutions.
        stride: Stride of the first main-path convolution and of the 1x1
            projection shortcut.
        conv_shortcut: If True, project the shortcut with a strided 1x1 conv + BN
            so it matches the main path; if False, use ``x`` unchanged, which
            requires ``stride == 1`` and ``filters`` equal to x's channel count.
        name: Prefix for all layer names; if None, Keras auto-generates names.

    Returns:
        The block's output tensor.

    Raises:
        ValueError: If an identity shortcut is requested but ``x`` cannot be
            added to the main path (stride != 1 or channel-count mismatch).
    """
    bn_axis = 3  # channel axis for 'channels_last' data format

    if conv_shortcut:
        # 1x1 projection so the shortcut matches the main path's channels/stride.
        shortcut = Conv2D(filters, 1, strides=stride, name=_layer_name(name, "_0_conv"))(x)
        shortcut = BatchNormalization(axis=bn_axis, name=_layer_name(name, "_0_bn"))(
            shortcut
        )
    else:
        in_channels = x.shape[bn_axis]
        # Fail early with a clear message instead of a shape error inside Add.
        if stride != 1 or (in_channels is not None and in_channels != filters):
            raise ValueError(
                "Identity shortcut requires stride=1 and input channels == filters "
                "(got stride={}, input channels={}, filters={}); "
                "use conv_shortcut=True instead.".format(stride, in_channels, filters)
            )
        shortcut = x

    x = Conv2D(
        filters,
        kernel_size,
        padding="same",
        strides=stride,
        name=_layer_name(name, "_1_conv"),
    )(x)
    x = BatchNormalization(axis=bn_axis, name=_layer_name(name, "_1_bn"))(x)
    x = Activation("relu", name=_layer_name(name, "_1_relu"))(x)

    x = Conv2D(filters, kernel_size, padding="same", name=_layer_name(name, "_2_conv"))(x)
    x = BatchNormalization(axis=bn_axis, name=_layer_name(name, "_2_bn"))(x)

    x = Add(name=_layer_name(name, "_add"))([shortcut, x])
    x = Activation("relu", name=_layer_name(name, "_out"))(x)
    return x


def _residual_stage_plan(stage_blocks):
    """Return ``(name, filters, stride, conv_shortcut)`` for every residual block.

    Stages are named ``conv2`` .. ``conv5`` (``conv1`` is the stem) and use 64,
    128, 256 and 512 filters. The first block of ``conv3``-``conv5`` downsamples
    with stride 2 and a projection shortcut. Every other block keeps the shape
    and uses an identity shortcut.

    Raises:
        ValueError: If ``stage_blocks`` does not have exactly one entry per stage.
    """
    stage_filters = (64, 128, 256, 512)
    if len(stage_blocks) != len(stage_filters):
        raise ValueError(
            "stage_blocks must have {} entries, got {}".format(
                len(stage_filters), len(stage_blocks)
            )
        )
    plan = []
    for stage, (filters, num_blocks) in enumerate(
        zip(stage_filters, stage_blocks), start=2
    ):
        for index in range(1, num_blocks + 1):
            downsample = stage > 2 and index == 1
            plan.append(
                (
                    "conv{}_{}".format(stage, index),
                    filters,
                    2 if downsample else 1,
                    downsample,
                )
            )
    return plan


def build_resnet(
    input_shape=(32, 32, 3), num_classes=10, *, stage_blocks=(3, 4, 6, 3)
):
    """Build a ResNet of basic (two-conv) residual blocks; ResNet-34 by default.

    Architecture: a 3x3 stride-2 conv stem with BN + ReLU, then a 3x3 stride-2
    max-pool. Next come four residual stages (conv2..conv5 with 64/128/256/512
    filters). Global average pooling and a softmax Dense classifier finish it.

    Args:
        input_shape: Shape of one input image in 'channels_last' layout.
        num_classes: Number of output units of the softmax classifier.
        stage_blocks: Residual blocks per stage. ``(3, 4, 6, 3)`` gives ResNet-34
            and ``(2, 2, 2, 2)`` gives ResNet-18.

    Returns:
        A Keras ``Model`` named ``"resnet<depth>"`` (``"resnet34"`` by default).

    Raises:
        ValueError: If ``stage_blocks`` does not have exactly four entries.
    """
    stage_blocks = tuple(stage_blocks)
    plan = _residual_stage_plan(stage_blocks)

    inputs = Input(shape=input_shape, name="input_layer")
    x = Conv2D(64, 3, strides=2, padding="same", name="conv1")(inputs)
    x = BatchNormalization(axis=3, name="bn_conv1")(x)
    x = Activation("relu", name="conv1_relu")(x)
    x = MaxPooling2D(3, strides=2, padding="same", name="pool1")(x)

    for block_name, filters, stride, conv_shortcut in plan:
        x = residual_block(
            x, filters, stride=stride, conv_shortcut=conv_shortcut, name=block_name
        )

    # The network downsamples 32x in total: 4x in the stem and 2x in each of
    # conv3-conv5. A 32x32 input therefore reaches the head as a 1x1 map, which a
    # fixed 4x4 "valid" AveragePooling2D cannot pool. Global pooling works for any
    # final spatial size and already yields a flat (batch, channels) tensor.
    x = GlobalAveragePooling2D(name="avg_pool")(x)
    # The "fc1000" name is kept unchanged (presumably inherited from the
    # 1000-class ImageNet model) so name-based references keep working; the
    # layer has num_classes units.
    outputs = Dense(num_classes, activation="softmax", name="fc1000")(x)

    depth = 2 * sum(stage_blocks) + 2  # two convs per block + stem conv + Dense
    return Model(inputs, outputs, name="resnet{}".format(depth))


# Build the default configuration (32x32x3 input, 10 output classes) and print
# its per-layer summary.
resnet_34 = build_resnet(num_classes=10, input_shape=(32, 32, 3))
resnet_34.summary()
