In [43]:
import tensorflow as tf 
from tensorflow import keras

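# By default TF reserves all GPU memory up front, which can make convolutions fail with a
# "could not initialize cuDNN" error. So we configure TF to allocate GPU memory on demand instead,
# which avoids that convolution error.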
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# TF2-native replacement for the TF1 `gpu_options.allow_growth` session option: enable memory
# growth on every visible GPU so TF allocates GPU memory on demand rather than reserving it all.
# This must run before any op touches the GPU; afterwards TF refuses the change with RuntimeError.
for _gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError as _err:
        # GPUs were already initialized (e.g. TF ops ran in this kernel before this cell);
        # report instead of crashing so the rest of the notebook can still run.
        print('Could not enable GPU memory growth:', _err)



In [44]:
class BasicBlock(keras.layers.Layer):
    """ResNet basic residual block: two 3x3 conv + BatchNorm stages plus a shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``. Only ``conv1`` uses
    ``stride``; ``conv2`` always has stride 1.

    The shortcut is the identity when the block preserves both the spatial size
    (``stride == 1``) and the channel count of its input. Otherwise a 1x1 convolution with
    the same stride projects the input so the two branches can be added. The choice is made
    in ``build`` because the input channel count is only known there.

    Args:
        filter_num: number of output channels of every convolution in the block.
        stride: stride of the first 3x3 convolution (and of the projection shortcut).
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filter_num, stride=1, **kwargs):
        super(BasicBlock, self).__init__(**kwargs)
        self.filter_num = filter_num
        self.stride = stride
        self.conv1 = keras.layers.Conv2D(filters=filter_num, kernel_size=(3, 3), strides=stride, padding='SAME')
        self.bn1 = keras.layers.BatchNormalization()
        # Stateless, so one instance is safely reused after bn1 and after the residual add.
        self.relu = keras.layers.Activation('relu')
        self.conv2 = keras.layers.Conv2D(filters=filter_num, kernel_size=(3, 3), strides=1, padding='SAME')
        self.bn2 = keras.layers.BatchNormalization()
        # Created once here; the functional `keras.layers.add` would build a new Add layer per call.
        self.merge = keras.layers.Add()
        self.shortcut = None  # chosen in build(), once the input channel count is known

    def build(self, input_shape):
        # NOTE(review): assumes channels_last (the Keras default), i.e. channels on the last axis.
        in_channels = tf.TensorShape(input_shape)[-1]
        if self.stride != 1 or in_channels != self.filter_num:
            # Projection shortcut: match the main branch's spatial size and channel count.
            self.shortcut = keras.Sequential()
            self.shortcut.add(keras.layers.Conv2D(filters=self.filter_num, kernel_size=(1, 1), strides=self.stride))
        else:
            self.shortcut = lambda x: x
        super(BasicBlock, self).build(input_shape)

    def call(self, inputs, training=None):
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        downsample = self.shortcut(inputs)
        output = self.merge([downsample, out])
        return self.relu(output)
        

In [45]:
class ResNet(keras.Model):
    """ResNet classifier built from ``BasicBlock`` units.

    Layout:
        * stem: 3x3 conv -> BN -> ReLU -> 2x2 max-pool with stride 1 (spatial size kept);
        * four residual stages with 64/128/256/512 filters, each a stack of BasicBlocks;
          stages 1-3 halve the spatial size in their first block;
        * global average pooling and a Dense layer producing ``num_class`` raw logits
          (no activation).

    For a 32x32 input, the feature map reaching the pooling layer is 4x4x512.

    Args:
        layer_dimensions: four positive ints, the number of BasicBlocks per stage;
            e.g. ``[2, 2, 2, 2]`` gives ResNet-18 and ``[3, 4, 6, 3]`` gives ResNet-34.
        num_class: number of output classes.
        **kwargs: forwarded to ``keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: if ``layer_dimensions`` does not hold exactly four counts >= 1.
    """

    def __init__(self, layer_dimensions, num_class, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        layer_dimensions = list(layer_dimensions)
        # Without this check, a count of 0 would still build one block (build_resblock always
        # adds the first), and entries beyond the fourth would be silently ignored.
        if len(layer_dimensions) != 4 or any(n < 1 for n in layer_dimensions):
            raise ValueError(
                'layer_dimensions must contain exactly four block counts >= 1, got %r'
                % (layer_dimensions,))

        # Stem applied to the raw input before the residual stages.
        self.root_preprepare = keras.Sequential([keras.layers.Conv2D(64, (3, 3), strides=(1, 1), padding='SAME'),
                                                 keras.layers.BatchNormalization(),
                                                 keras.layers.Activation('relu'),
                                                 keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='SAME')
                                                 ])
        self.layer0 = self.build_resblock(filter_num=64, block_num=layer_dimensions[0])
        self.layer1 = self.build_resblock(filter_num=128, block_num=layer_dimensions[1], stride=2)
        self.layer2 = self.build_resblock(filter_num=256, block_num=layer_dimensions[2], stride=2)
        self.layer3 = self.build_resblock(filter_num=512, block_num=layer_dimensions[3], stride=2)

        # With Keras' default channels_last format, layer3 outputs [b, h, w, 512];
        # GlobalAveragePooling2D averages over h and w, giving [b, 512].
        self.avgpooling = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(num_class)

    def call(self, inputs, training=None):
        x = self.root_preprepare(inputs, training=training)
        x = self.layer0(x, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.avgpooling(x)
        # Raw logits of shape [b, num_class]; no softmax is applied here.
        return self.fc(x)

    def build_resblock(self, filter_num, block_num, stride=1):
        """Stack ``block_num`` BasicBlocks; only the first one applies ``stride``."""
        resblocks = keras.Sequential()
        resblocks.add(BasicBlock(filter_num, stride))
        for _ in range(1, block_num):
            resblocks.add(BasicBlock(filter_num, stride=1))  # later blocks never downsample
        return resblocks

In [46]:
# Factory functions for ResNets of different depths; the number of classes defaults to 100
# (CIFAR-100) but can be overridden.
def ResNet18(num_class=100):
    """ResNet-18: four stages of 2 BasicBlocks each, with ``num_class`` output logits."""
    return ResNet(layer_dimensions=[2, 2, 2, 2], num_class=num_class)


def ResNet34(num_class=100):
    """ResNet-34: stages of 3/4/6/3 BasicBlocks, with ``num_class`` output logits."""
    return ResNet(layer_dimensions=[3, 4, 6, 3], num_class=num_class)

In [47]:
# CIFAR-100 from Keras: an (images, labels) pair for each of the train and test splits.
(data_train, label_train), (data_test, label_test) = keras.datasets.cifar100.load_data()
# Labels arrive with a trailing singleton axis; squeeze it so each split has a flat label vector.
label_train, label_test = [tf.squeeze(labels, axis=1) for labels in (label_train, label_test)]
data_train.shape, label_train.shape

((50000, 32, 32, 3), TensorShape([50000]))

In [48]:
def preprocess(x, y):
    """Map one (image, label) dataset element to model-ready tensors.

    The image is cast to float32 and divided by 255 (0-255 pixel values end up in [0, 1]);
    the label is cast to int32.
    """
    return tf.cast(x, dtype=tf.float32) / 255., tf.cast(y, dtype=tf.int32)

In [49]:
def _make_dataset(images, labels, batch_size, shuffle):
    """Build a preprocessed, batched, prefetched tf.data pipeline from in-memory arrays."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        ds = ds.shuffle(10000)
    return ds.map(preprocess).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


def _evaluate(model, dataset):
    """Return the top-1 accuracy of ``model`` over ``dataset`` (0.0 if the dataset is empty)."""
    total_num = 0
    total_correct = 0
    for x, y in dataset:
        logits = model(x, training=False)  # inference mode: BN uses its moving statistics
        # argmax over logits equals argmax over softmax(logits), so the softmax is skipped.
        pred = tf.argmax(logits, axis=1, output_type=tf.int32)
        total_correct += int(tf.reduce_sum(tf.cast(tf.equal(pred, y), dtype=tf.int32)))
        total_num += int(x.shape[0])
    return total_correct / total_num if total_num else 0.0


def main(epochs=50, batch_size=256, learning_rate=1e-4, log_every=100):
    """Train a ResNet18 on CIFAR-100 and print test accuracy after every epoch.

    Reads the module-level ``data_train``/``label_train``/``data_test``/``label_test``
    loaded in an earlier cell.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size for training and evaluation.
        learning_rate: Adam learning rate.
        log_every: print the training loss every ``log_every`` steps.
    """
    train_db = _make_dataset(data_train, label_train, batch_size, shuffle=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    test_db = _make_dataset(data_test, label_test, batch_size, shuffle=False)

    model = ResNet18()
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                logits = model(x, training=True)
                # Sparse CE on integer labels equals one_hot + categorical CE, without the one-hot.
                loss = tf.reduce_mean(
                    keras.losses.sparse_categorical_crossentropy(y, logits, from_logits=True))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if step % log_every == 0:
                print(epoch, step, 'loss: ', loss.numpy())

        acc = _evaluate(model, test_db)
        print(epoch, 'Accuracy: ', acc)

In [50]:
# Long-running: trains with main()'s default of 50 epochs, printing loss and test accuracy as it goes.
main()

0 0 loss:  5.421026
0 100 loss:  3.2155776
0 Accuracy:  0.0139
1 0 loss:  3.039382
1 100 loss:  2.644591
1 Accuracy:  0.0199
2 0 loss:  2.412082
2 100 loss:  2.218904
2 Accuracy:  0.2274
3 0 loss:  1.9199833
3 100 loss:  1.6484944
3 Accuracy:  0.3383
4 0 loss:  1.4039509
4 100 loss:  1.1718652
4 Accuracy:  0.3467
5 0 loss:  0.9372132
5 100 loss:  0.6304083
5 Accuracy:  0.3373
6 0 loss:  0.42826736
6 100 loss:  0.18909991
6 Accuracy:  0.3653
7 0 loss:  0.14116764
7 100 loss:  0.07670945
7 Accuracy:  0.3813
8 0 loss:  0.067004085
8 100 loss:  0.036363807
8 Accuracy:  0.3897
9 0 loss:  0.020149644
9 100 loss:  0.030588757
9 Accuracy:  0.3868
10 0 loss:  0.017170336
10 100 loss:  0.013866104
10 Accuracy:  0.3853
11 0 loss:  0.010362391
11 100 loss:  0.01578324
11 Accuracy:  0.3888
12 0 loss:  0.011500987
12 100 loss:  0.02702912
12 Accuracy:  0.3738
13 0 loss:  0.011322264
13 100 loss:  0.006753711
13 Accuracy:  0.3125
14 0 loss:  0.028831156
14 100 loss:  0.2970064
14 Accuracy:  0.1873
15

KeyboardInterrupt: 