In [None]:
import tensorflow as tf

In [None]:
class ResidualBlock(tf.keras.layers.Layer):
    """Basic residual block built from two 3x3 convolutions.

    Main path: 3x3 conv (strided) -> batch norm -> ReLU -> 3x3 conv -> batch norm.
    The main-path output is added to a shortcut of the block input and the sum
    is passed through ReLU.

    Args:
        filter_num: Number of filters of both 3x3 convolutions (and of the 1x1
            shortcut convolution when one is created).
        stride: Stride of the first 3x3 convolution. When it is not 1, the
            shortcut is a 1x1 convolution with the same stride and
            ``filter_num`` filters; otherwise the input is passed through
            unchanged.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Note:
        With ``stride == 1`` the shortcut is the identity, so the block input
        must already have ``filter_num`` channels for the addition to succeed.
    """

    def __init__(self, filter_num, stride=1, **kwargs):
        super().__init__(**kwargs)
        # First layer of the residual block (carries the stride).
        self.conv1 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=3, strides=stride, padding='same')
        self.batch1 = tf.keras.layers.BatchNormalization()
        self.activation1 = tf.keras.layers.Activation('relu')

        # Second layer of the residual block (always stride 1).
        self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=3, strides=1, padding='same')
        self.batch2 = tf.keras.layers.BatchNormalization()

        # Shortcut path: a strided 1x1 projection when the block downsamples,
        # a plain pass-through otherwise.
        if stride != 1:
            self.downsample = tf.keras.Sequential(
                [tf.keras.layers.Conv2D(filters=filter_num, kernel_size=1, strides=stride)]
            )
        else:
            self.downsample = lambda x: x

        # Created once here instead of building a new Add layer on every call.
        self.merge = tf.keras.layers.Add()

    def call(self, x, training=None):
        """Run the block on ``x``.

        Args:
            x: Input tensor for the convolutions.
            training: Forwarded to both batch-normalization layers.

        Returns:
            ReLU of (main-path output + shortcut).
        """
        output = self.conv1(x)
        output = self.batch1(output, training=training)
        output = self.activation1(output)
        output = self.conv2(output)
        output = self.batch2(output, training=training)
        identity = self.downsample(x)
        return tf.nn.relu(self.merge([output, identity]))


class ResidualNetwork(tf.keras.Model):
    """ResNet-style classifier.

    Pipeline: stem (conv -> batch norm -> ReLU -> max pool) -> four residual
    stages with 64/128/256/512 filters -> global average pooling -> softmax.

    Args:
        layers: Sequence giving the number of residual blocks in each of the
            four stages; indices 0 to 3 are read (e.g. ``[2, 2, 2, 2]``).
        num_classifies: Number of output classes, i.e. units of the final
            softmax ``Dense`` layer.
    """

    def __init__(self, layers, num_classifies):
        super().__init__()
        # Stem: 3x3 convolution followed by batch norm, ReLU and a stride-1 max pool.
        self.first = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=1, input_shape=(28, 28, 1)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPool2D(pool_size=2, strides=1, padding='same'),
        ])
        # From the second stage onward, the first block of every stage uses
        # stride 2 and therefore carries a downsample shortcut.
        self.reslayers = [
            self.rbuild(64, layers[0]),
            self.rbuild(128, layers[1], stride=2),
            self.rbuild(256, layers[2], stride=2),
            self.rbuild(512, layers[3], stride=2),
        ]
        # Built once here rather than instantiating a new pooling layer on
        # every forward pass.
        self.global_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.opt = tf.keras.layers.Dense(num_classifies, activation='softmax')

    def call(self, x, training=None):
        """Forward pass.

        Args:
            x: Batch of input images.
            training: Forwarded to the stem and to every residual stage.

        Returns:
            Output of the softmax layer, one probability per class.
        """
        x = self.first(x, training=training)
        for stage in self.reslayers:
            x = stage(x, training=training)
        x = self.global_pool(x)
        return self.opt(x)

    def rbuild(self, filter_num, blocks, stride=1):
        """Build one residual stage.

        Args:
            filter_num: Number of filters for every block of the stage.
            blocks: Requested number of blocks. The stage always contains the
                first block, followed by ``blocks - 1`` additional ones.
            stride: Stride of the first block only; the remaining blocks use
                stride 1.

        Returns:
            A ``tf.keras.Sequential`` holding the stage's residual blocks.
        """
        stage = tf.keras.Sequential([ResidualBlock(filter_num, stride)])
        for _ in range(blocks - 1):
            stage.add(ResidualBlock(filter_num, stride=1))
        return stage


In [None]:
# Two residual blocks in each of the four stages, ten output classes.
resNet18 = ResidualNetwork(layers=[2, 2, 2, 2], num_classifies=10)

# Sparse categorical cross-entropy takes integer class labels rather than
# one-hot vectors.
resNet18.compile(
    loss="sparse_categorical_crossentropy",
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Load the MNIST train/test splits that ship with Keras.
data = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = data.load_data()

# Per the Keras docs the images arrive as (N, 28, 28) uint8 arrays with values
# in [0, 255]. Add the trailing channel axis explicitly so they match the
# model's declared (28, 28, 1) input, and rescale to [0, 1] float32 instead of
# feeding raw integer pixel values to the network.
x_train = x_train[..., None].astype("float32") / 255.0
x_test = x_test[..., None].astype("float32") / 255.0

In [None]:
# Train on the training split: 5 passes over the data in mini-batches of 64.
resNet18.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=5,
)

In [None]:
# Report the loss and the compiled metrics on the held-out test split.
resNet18.evaluate(x=x_test, y=y_test)

In [None]:
# Print the model's layer/parameter overview. This runs after fit(), by which
# point the subclassed model has been called on data and is therefore built.
resNet18.summary()