## ResNet

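ResNet (residual network) was developed by Kaiming He and colleagues and won the ILSVRC 2015 competition. The original network is a very deep convolutional neural network with 152 layers. What makes training such a deep network possible is the skip connection (also called a shortcut connection): the signal fed into a layer is also added to the output of a layer higher up the stack. Here we borrow that core idea to build a simplified ResNet18.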

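ResNet is built from convolutional layers, max pooling layers, average pooling layers, and fully connected layers. The input first passes through a convolutional layer and a max pooling layer, and then enters the deep stack of residual units.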

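We can now define the **residual unit**. The input features pass through a convolutional layer, then Batch Normalization and a ReLU activation, then a second convolutional layer followed by another Batch Normalization. The output of this branch is called the residual. Adding the residual to the original input features gives the unit's result, which is passed through ReLU and then fed to the next residual unit.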
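The data flow of a residual unit can be sketched in a few lines of NumPy. This is only an illustration under simplifying assumptions: dense matrix products stand in for the convolutions, Batch Normalization is omitted, and the names `toy_residual_unit`, `w1`, and `w2` are hypothetical rather than part of this notebook.

```python
import numpy as np

def relu(z):
    """Element-wise ReLU."""
    return np.maximum(z, 0.0)

def toy_residual_unit(x, w1, w2):
    """Toy residual unit: dense layers replace the convolutions, no BatchNorm."""
    residual = relu(x @ w1) @ w2   # F(x): layer -> ReLU -> layer
    return relu(residual + x)      # add the shortcut, then apply ReLU

rng = np.random.default_rng(0)
x = rng.random((4, 8))             # 4 samples, 8 non-negative features
w1 = rng.normal(size=(8, 8))
w2 = np.zeros((8, 8))              # zero weights switch the residual branch off

y = toy_residual_unit(x, w1, w2)
print(np.allclose(y, x))           # True: with F(x) = 0 the unit passes x through
```

With the residual branch switched off, the unit returns its (non-negative) input unchanged, which is the behaviour the shortcut is meant to make easy to learn.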

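The reason for residual learning is that a very deep plain network may suffer from vanishing gradients, exploding gradients, and degradation (accuracy getting worse as more layers are added). With residual learning, each unit only has to learn a correction to its input. If no correction helps, the unit can push its residual toward zero and pass the input through, so deeper layers can preserve what the earlier layers have already achieved.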
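One way to see why the shortcut helps gradients, sketched under the assumption that a unit with an identity shortcut computes $z = x + F(x)$ before its final ReLU (the symbols $z$, $F$, and $\mathcal{L}$ for the loss are introduced here for illustration only):

$$
\frac{\partial \mathcal{L}}{\partial x} = \left(I + \frac{\partial F}{\partial x}\right)^{\top} \frac{\partial \mathcal{L}}{\partial z}
$$

The identity term $I$ gives the gradient a direct path back to $x$, even when $\partial F / \partial x$ is small. With a projection shortcut, $I$ is replaced by the Jacobian of the projection.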

### 1. Import the required modules

In [8]:
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

### 2. Load the dataset

Here we use the MNIST dataset that ships with TensorFlow.

In [2]:
# All four variables are NumPy arrays.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_images.shape)
print(train_labels.shape)
print(test_images.shape)
print(test_labels.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)


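### 3. Data preprocessing

Each 28×28 image is zero-padded to 32×32 so that it matches the network's input size. A channel axis is then added so that every image is 3-dimensional (height, width, channel), as the convolutional layers expect. The labels are converted to one-hot encoding so that the categorical cross-entropy loss can be used during training.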

In [3]:
train_images_32 = np.zeros((60000, 32, 32), dtype=train_images.dtype)
test_images_32 = np.zeros((10000, 32, 32), dtype=test_images.dtype)

start_row = (32 - 28) // 2
start_col = (32 - 28) // 2
for i in range(60000):
  train_images_32[i][start_row:start_row+28, start_col:start_col+28] = train_images[i]
for i in range(10000):
  test_images_32[i][start_row:start_row+28, start_col:start_col+28] = test_images[i]

train_images_32 = train_images_32.reshape((60000, 32, 32, 1)).astype('float32') / 255
test_images_32 = test_images_32.reshape((10000, 32, 32, 1)).astype('float32') / 255

train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

print(train_images_32.shape)
print(test_images_32.shape)
print(train_labels.shape)
print(test_labels.shape)

(60000, 32, 32, 1)
(10000, 32, 32, 1)
(60000, 10)
(10000, 10)
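The same centered zero-padding can likely be written without explicit loops using `np.pad`. A minimal sketch on a small random batch (the array `batch` and the other variable names are hypothetical, chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a few MNIST images: 5 images of 28x28 uint8 pixels.
batch = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

pad = (32 - 28) // 2  # 2 pixels on each side
# Pad only the height and width axes; leave the batch axis untouched.
padded = np.pad(batch, ((0, 0), (pad, pad), (pad, pad)),
                mode="constant", constant_values=0)

# Add the channel axis and scale to [0, 1], as in the cell above.
padded = padded[..., np.newaxis].astype("float32") / 255

print(padded.shape)  # (5, 32, 32, 1)
```

Because the pad widths are derived from the array itself, this version does not need the dataset sizes hard-coded.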


### 4. 搭建神经网络

In [4]:
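# Residual unit
def residual_block(x, filters, stride=1):
  shortcut = x  # keep the input for the skip connection

  # First convolution (downsamples when stride > 1), then BatchNorm and ReLU
  x = layers.Conv2D(filters, 3, strides=stride, padding='same')(x)
  x = layers.BatchNormalization()(x)
  x = layers.ReLU()(x)

  # Second convolution and BatchNorm; ReLU is applied only after the addition
  x = layers.Conv2D(filters, 3, strides=1, padding='same')(x)
  x = layers.BatchNormalization()(x)

  # Project the shortcut with a 1x1 convolution when its shape would not match
  if stride != 1 or shortcut.shape[-1] != filters:
    shortcut = layers.Conv2D(filters, 1, strides=stride, padding='same')(shortcut)
    shortcut = layers.BatchNormalization()(shortcut)

  # Add the residual to the shortcut, then activate
  x = layers.add([x, shortcut])
  x = layers.ReLU()(x)

  return x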

# Build ResNet18
def Resnet_18(input_shape=(32, 32, 1), num_classes=10):
  input_tensor = tf.keras.Input(shape=input_shape)

  x = layers.Conv2D(64, 7, strides=2, padding='same')(input_tensor)
  x = layers.ReLU()(x)
  x = layers.MaxPooling2D(3, strides=2, padding='same')(x)

  residual_blocks = [2, 2, 2, 2]
  filters_list = [64, 128, 256, 512]

  for stage, (num_blocks, filters) in enumerate(zip(residual_blocks, filters_list)):
    for block in range(num_blocks):
      # Downsample in the first unit of every stage except the first.
      stride = 2 if stage > 0 and block == 0 else 1
      x = residual_block(x, filters, stride=stride)

  # With a 32x32 input the feature map is already 1x1 at this point, so the
  # average pooling used in the original ResNet is left out here.
  x = layers.Flatten()(x)
  x = layers.Dense(1000, activation='relu')(x)
  output = layers.Dense(num_classes, activation='softmax')(x)

  model = models.Model(inputs=input_tensor, outputs=output)
  return model
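
The spatial size of the feature map at each step can be traced by hand. The sketch below assumes the Keras rule that a layer with `padding='same'` outputs `ceil(size / stride)` along each spatial axis; the helper `same_out` is hypothetical and only mirrors the strides used in `Resnet_18` above.

```python
import math

def same_out(size, stride):
    """Spatial size after a layer with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)

size = 32
trace = []

size = same_out(size, 2)   # 7x7 convolution, stride 2
trace.append(size)
size = same_out(size, 2)   # 3x3 max pooling, stride 2
trace.append(size)

for stage in range(4):     # four stages of residual units
    # Only the first unit of stages 1 to 3 downsamples; stage 0 keeps the size.
    stride = 2 if stage > 0 else 1
    size = same_out(size, stride)
    trace.append(size)

print(trace)  # [16, 8, 8, 4, 2, 1]
```

Under this arithmetic the last stage produces a 1×1 map with 512 channels, so `Flatten` yields 512 features per image.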

### 5. Compile the model

In [7]:
model = Resnet_18()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
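
To make the loss concrete, here is a NumPy sketch of what categorical cross-entropy measures on one-hot labels. It illustrates the formula only and is not Keras's exact implementation; the helpers `one_hot` and `categorical_crossentropy`, the clipping constant `eps`, and the sample values are assumptions made for this example.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Turn integer labels into one-hot rows."""
    return np.eye(num_classes, dtype="float32")[labels]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over the batch of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y_true = one_hot(np.array([0, 2]), num_classes=3)
y_pred = np.array([[0.7, 0.2, 0.1],    # true class 0 gets probability 0.7
                   [0.1, 0.1, 0.8]])   # true class 2 gets probability 0.8

loss = categorical_crossentropy(y_true, y_pred)
print(round(loss, 4))  # 0.2899
```

Only the probability assigned to the true class contributes, so the loss here is the mean of `-log(0.7)` and `-log(0.8)`.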

### 6. Train the model

In [9]:
start_time = time.time()
model.fit(train_images_32, train_labels, epochs=5, batch_size=64, validation_split=0.2)
end_time = time.time()
print("Training Time:", end_time - start_time, "seconds")

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Training Time: 148.23498797416687 seconds
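
The effect of `validation_split=0.2` and `batch_size=64` on the data can be worked out with plain arithmetic. Per the Keras documentation, the validation samples are taken from the end of the arrays before shuffling; the variable names below are hypothetical and the numbers follow from the 60000 training images.

```python
import math

num_samples = 60000
validation_split = 0.2
batch_size = 64

num_val = round(num_samples * validation_split)      # last 12000 samples validate
split_at = num_samples - num_val                     # first 48000 samples train
steps_per_epoch = math.ceil(split_at / batch_size)   # batches per epoch

print(split_at, num_val, steps_per_epoch)  # 48000 12000 750
```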


In [None]:
test_loss, test_acc = model.evaluate(test_images_32, test_labels)
print(f'Test accuracy: {test_acc}')

Test accuracy: 0.9807999730110168
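
The accuracy metric compares the most probable class of each prediction with the true class. A minimal NumPy sketch on made-up values (the arrays `probs` and `labels` are hypothetical and not output from this model):

```python
import numpy as np

# Hypothetical softmax outputs for 4 samples and 3 classes.
probs = np.array([[0.8, 0.1, 0.1],
                  [0.2, 0.7, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
# One-hot ground truth for classes 0, 1, 2, 1.
labels = np.eye(3)[[0, 1, 2, 1]]

predicted = probs.argmax(axis=1)   # most probable class per sample
actual = labels.argmax(axis=1)     # undo the one-hot encoding
accuracy = float(np.mean(predicted == actual))

print(accuracy)  # 0.75
```

With the trained model, `probs` would come from `model.predict(test_images_32)` and `labels` would be `test_labels`.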
