In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import activations
from tensorflow.keras.datasets import mnist

In [2]:
# Residual block used by ResNet: two 3x3 conv + BN stages plus a shortcut that
# is either the identity or a 1x1 projection.
class Residual(tf.keras.Model):
    """Basic two-convolution residual block.

    Args:
        num_channels: number of filters in every convolution of the block.
        use_1x1conv: if True, the shortcut goes through a 1x1 convolution so
            its channel count / spatial size matches the main path's output.
        strides: stride of the first 3x3 convolution and of the 1x1 shortcut.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1):
        super(Residual, self).__init__()
        # Main path: two 3x3 convolutions; only the first one may downsample.
        self.conv1 = tf.keras.layers.Conv2D(num_channels, kernel_size=3, padding='same', strides=strides)
        self.conv2 = tf.keras.layers.Conv2D(num_channels, kernel_size=3, padding='same')

        # Optional 1x1 projection on the shortcut so X can be added to Y when
        # the number of channels or the spatial size changes.
        if use_1x1conv:
            self.conv3 = tf.keras.layers.Conv2D(num_channels, kernel_size=1, strides=strides)
        else:
            self.conv3 = None

        # Batch normalization after each main-path convolution.
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, X, training=None):
        """Forward pass: relu(BN(conv2(relu(BN(conv1(X))))) + shortcut(X)).

        Args:
            X: input tensor; assumes channels-last (NHWC) layout, as used
                elsewhere in this notebook.
            training: forwarded to the BatchNormalization layers. None lets
                Keras take the value from the enclosing call (fit vs.
                evaluate/predict).
        """
        # conv -> BN -> ReLU
        Y = activations.relu(self.bn1(self.conv1(X), training=training))
        # conv -> BN (activation is applied after the residual addition)
        Y = self.bn2(self.conv2(Y), training=training)

        # Test against None explicitly instead of relying on the truthiness of
        # a Keras layer object.
        if self.conv3 is not None:
            X = self.conv3(X)

        # Add the shortcut and activate.
        return activations.relu(Y + X)

In [3]:
# One ResNet stage: a sequence of Residual blocks sharing the same channel count.
class ResnetBlock(tf.keras.layers.Layer):
    """A ResNet stage made of ``num_residuals`` Residual blocks.

    Args:
        num_channels: output channels (filters) of every residual block.
        num_residuals: number of Residual blocks in the stage.
        first_block: True for the first stage of the network. In every other
            stage, the first residual block uses stride 2 and a 1x1-conv
            shortcut (shrinks the feature map, changes the channel count).
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_channels, num_residuals, first_block=False, **kwargs):
        super(ResnetBlock, self).__init__(**kwargs)

        # Residual blocks of this stage, applied in order by call().
        self.listLayers = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                # Downsample by 2 and project the shortcut so X stays addable to Y.
                self.listLayers.append(Residual(num_channels, use_1x1conv=True, strides=2))
            else:
                self.listLayers.append(Residual(num_channels))

    def call(self, X):
        """Apply the stage's residual blocks sequentially."""
        # Iterate the list itself. `self.listLayers.layers` only works because
        # TensorFlow wraps tracked lists in an internal ListWrapper exposing
        # `.layers`; that is not a list API and may not exist in newer Keras.
        for layer in self.listLayers:
            X = layer(X)
        return X

In [4]:
# ResNet classifier: 7x7 conv stem, four residual stages (64/128/256/512
# channels), global average pooling and a softmax dense head.
class ResNet(tf.keras.Model):
    """ResNet with four stages of basic residual blocks.

    Args:
        num_blocks: sequence of exactly four ints, the number of residual
            blocks in each stage (e.g. [3, 4, 6, 3] for the ResNet-34 layout).
        num_classes: number of output classes of the softmax head. The default
            of 10 matches the MNIST digits used in this notebook.

    Raises:
        ValueError: if ``num_blocks`` does not have exactly four entries.
    """

    def __init__(self, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        if len(num_blocks) != 4:
            raise ValueError(
                f"num_blocks must have 4 entries (one per stage), got {len(num_blocks)}"
            )

        # Stem: 7x7 convolution with stride 2.
        self.conv = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.bn = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.Activation('relu')
        # 3x3 max pooling with stride 2.
        self.mp = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')

        # Residual stages. Only the first keeps the stem's resolution; each
        # later stage halves the feature map and doubles the channel count.
        self.resnet_block1 = ResnetBlock(64, num_blocks[0], first_block=True)
        self.resnet_block2 = ResnetBlock(128, num_blocks[1])
        self.resnet_block3 = ResnetBlock(256, num_blocks[2])
        self.resnet_block4 = ResnetBlock(512, num_blocks[3])

        # Head: global average pooling, then a fully connected softmax classifier.
        self.gap = tf.keras.layers.GlobalAvgPool2D()
        self.fc = tf.keras.layers.Dense(units=num_classes, activation=activations.softmax)

    def call(self, x, training=None):
        """Forward pass returning per-class softmax probabilities.

        Args:
            x: input image batch (channels-last, as elsewhere in this notebook).
            training: forwarded to the stem BatchNormalization layer. Nested
                layers receive it through Keras' call context.
        """
        x = self.conv(x)
        x = self.bn(x, training=training)
        x = self.relu(x)
        x = self.mp(x)

        x = self.resnet_block1(x)
        x = self.resnet_block2(x)
        x = self.resnet_block3(x)
        x = self.resnet_block4(x)

        x = self.gap(x)
        return self.fc(x)

In [5]:
# Instantiate the model with 3, 4, 6 and 3 residual blocks in its four stages.
my_net = ResNet(num_blocks=[3, 4, 6, 3])

In [6]:
# Run one random 224x224 single-channel image through the subclassed model so
# its weights are created, then print the layer summary.
X = tf.random.uniform((1, 224, 224, 1))
y = my_net(X)
my_net.summary()

Model: "res_net"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             multiple                  3200      
                                                                 
 batch_normalization (BatchN  multiple                 256       
 ormalization)                                                   
                                                                 
 activation (Activation)     multiple                  0         
                                                                 
 max_pooling2d (MaxPooling2D  multiple                 0         
 )                                                               
                                                                 
 resnet_block (ResnetBlock)  multiple                  223104    
                                                                 
 resnet_block_1 (ResnetBlock  multiple                 1119

In [7]:
# Load MNIST, then append a trailing channel axis to both splits so the
# images are laid out as N x H x W x C (C = 1).
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = np.expand_dims(train_images, axis=-1)
test_images = np.expand_dims(test_images, axis=-1)


In [8]:
# Helpers that draw a small random subset of the data for a quick demo run.
def _sample_resized(images, labels, size, target_size=224):
    """Draw ``size`` distinct random samples and pad-resize them to a square.

    Args:
        images: image array indexed along axis 0 (N x H x W x C here).
        labels: label array aligned with ``images`` along axis 0.
        size: number of samples to draw; must not exceed the number of images.
        target_size: output height and width passed to
            ``tf.image.resize_with_pad`` (aspect ratio kept, zero padding).

    Returns:
        Tuple ``(resized_images, sampled_labels)`` of NumPy arrays.

    Raises:
        ValueError: if ``size`` is larger than the number of images
            (sampling is without replacement).
    """
    # Sample without replacement so a small demo subset contains no duplicates.
    index = np.random.choice(np.shape(images)[0], size, replace=False)
    resized_images = tf.image.resize_with_pad(images[index], target_size, target_size)
    return resized_images.numpy(), labels[index]


# The default arrays are bound when this cell runs, i.e. to the full MNIST
# splits from the loading cell. A later cell rebinds train_images/test_images
# to the small resized subsets; early binding keeps re-running that cell correct.
def get_train(size, images=train_images, labels=train_labels, target_size=224):
    """Return ``size`` random training samples resized to target_size x target_size."""
    return _sample_resized(images, labels, size, target_size)


def get_test(size, images=test_images, labels=test_labels, target_size=224):
    """Return ``size`` random test samples resized to target_size x target_size."""
    return _sample_resized(images, labels, size, target_size)

In [9]:
# Draw the demo subsets: 256 training and 128 test images, resized to 224x224.
# NOTE: this rebinds train_images/test_images (previously the full splits).
train_images, train_labels = get_train(size=256)
test_images, test_labels = get_test(size=128)

In [10]:
# Plain SGD (no momentum); integer class labels, hence the sparse loss.
optimizer = tf.keras.optimizers.SGD(momentum=0.0, learning_rate=0.01)
my_net.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [11]:
# Train for 3 epochs with batch size 128, holding out 10% of the training
# subset (Keras takes the last samples, before shuffling) for validation.
history = my_net.fit(
    train_images,
    train_labels,
    batch_size=128,
    epochs=3,
    verbose=1,
    # validation_split is what holds data out. validation_batch_size only sets
    # the batch size of validation data and is ignored when none is given.
    validation_split=0.1,
)

Epoch 1/3


2024-08-16 09:50:35.581545: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x16af47af0>

In [12]:
# Report [loss, accuracy] on the sampled test subset.
# NOTE(review): after only 3 epochs on 256 images, accuracy near chance is
# plausible; treat this as a smoke test, not a benchmark.
my_net.evaluate(x=test_images, y=test_labels, verbose=1)



[19.455408096313477, 0.0703125]