In [1]:
from datetime import datetime
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [47]:
# 自定义BisicBlock

class BasicBlock(keras.layers.Layer):
    def __init__(self, filter_num, strides=1):
        super(BasicBlock, self).__init__()
        self.conv1 = keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=strides, padding='same')
        self.bn1 = keras.layers.BatchNormalization()
        self.relu = keras.layers.Activation('relu')
        
        self.conv2 = keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=1, padding='same')
        self.bn2 = keras.layers.BatchNormalization()
        
        if strides != 1:
            self.down_sample = keras.Sequential()
            self.down_sample.add(keras.layers.Conv2D(filter_num, kernel_size=(1, 1), strides=strides))
            self.down_sample.add(keras.layers.BatchNormalization())
            
        else:
            self.down_sample = lambda x: x
        
        self.strides = strides
    
    def call(self, input_, training=None):
        
        conv1 = self.conv1(input_)
        bn1 = self.bn1(conv1)
        relu = self.relu(bn1)
        
        conv2 = self.conv2(relu)
        bn2 = self.bn2(conv2)
        
        residual = self.down_sample(input_)
        
        add = keras.layers.add([bn2, residual])
        out = tf.nn.relu(add)
        return out
    

In [78]:
# 自定义ResNet
class ResNet(keras.Model):
    def __init__(self, layer_dim, num_classes=100): # layer_dim = [2, 2, 2, 2]
        super(ResNet, self).__init__()
        # 根
        self.stem = keras.Sequential([
            keras.layers.Conv2D(64, kernel_size=(3, 3), strides=(1, 1)),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
            keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')
        ])
        
        self.layer1 = self.build_block(64, layer_dim[0])
        self.layer2 = self.build_block(128, layer_dim[1], strides=2)
        self.layer3 = self.build_block(256, layer_dim[2], strides=2)
        self.layer4 = self.build_block(512, layer_dim[3], strides=2)
        
        # output: [b, 512, h, w],
        self.avg_pool = keras.layers.GlobalAveragePooling2D()
        self.fc = keras.layers.Dense(num_classes)
     
    def call(self, input_, training=None):
        x = self.stem(input_)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avg_pool(x)
        x = self.fc(x)
        
        return x
    
    def build_block(self, filter_num, blocks, strides=1):
        res_blocks = keras.Sequential()
        
        # 第一层可能是下采样层
        res_blocks.add(BasicBlock(filter_num, strides))
        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, strides=1))   
        return res_blocks
    

# class ResNet(keras.Model):


#     def __init__(self, layer_dims, num_classes=100): # [2, 2, 2, 2]
#         super(ResNet, self).__init__()

#         self.stem = keras.Sequential([keras.layers.Conv2D(64, (3, 3), strides=(1, 1)),
#                                 keras.layers.BatchNormalization(),
#                                 keras.layers.Activation('relu'),
#                                 keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')
#                                 ])

#         self.layer1 = self.build_resblock(64,  layer_dims[0])
#         self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
#         self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
#         self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

#         # output: [b, 512, h, w],
#         self.avgpool = keras.layers.GlobalAveragePooling2D()
#         self.fc = keras.layers.Dense(num_classes)





#     def call(self, inputs, training=None):

#         x = self.stem(inputs)

#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         x = self.layer4(x)

#         # [b, c]
#         x = self.avgpool(x)
#         # [b, 100]
#         x = self.fc(x)

#         return x



#     def build_resblock(self, filter_num, blocks, stride=1):

#         res_blocks = keras.Sequential()
#         # may down sample
#         res_blocks.add(BasicBlock(filter_num, stride))

#         for _ in range(1, blocks):
#             res_blocks.add(BasicBlock(filter_num, strides=1))

#         return res_blocks

In [79]:
def resnet18():
    return ResNet([2, 2, 2, 2])

In [51]:
def preprocess_data(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255
    y = tf.cast(y, dtype=tf.int32)
    
    return x, y

In [52]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

In [53]:
batch_size = 128

db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.map(preprocess_data).shuffle(10000).batch(batch_size)

db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess_data).batch(batch_size)

In [57]:
sample = next(iter(db_train))
print('sample:', sample[0].shape, sample[1].shape)

sample: (128, 32, 32, 3) (128, 1)


In [83]:
model = resnet18()

model.build(input_shape=(None, 32, 32, 3))

# optimizer = keras.optimizers.Adam(1e-3)

model.summary()

Model: "res_net_21"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_128 (Sequential)  multiple                  2048      
_________________________________________________________________
sequential_129 (Sequential)  multiple                  148736    
_________________________________________________________________
sequential_130 (Sequential)  multiple                  527488    
_________________________________________________________________
sequential_132 (Sequential)  multiple                  2103552   
_________________________________________________________________
sequential_134 (Sequential)  multiple                  8401408   
_________________________________________________________________
global_average_pooling2d_15  multiple                  0         
_________________________________________________________________
dense_18 (Dense)             multiple                  5

In [86]:

# 由于太慢了 就不运行了
for epoch in range(30):
    for step, (x, y) in enumerate(db_train):
        with tf.GradientTape() as tap:
            # [b, 32, 32, 3] => [b, 1, 1, 512]
            # [b, 1, 1, 512] => [b, 512]
            # [b, 512] => [b, 100]
            logits = model(x)
            y_true = tf.one_hot(y, depth=100)
            
            loss = tf.losses.categorical_crossentropy(y_true, logits, from_logits=True)
            loss = tf.reduce_mean(loss)
            
        grads = tap.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        if step %100 == 0:
            print(epoch, step, 'loss:', float(loss))
        
        # 测试数据
        total_num, total_correct = 0, 0
        for x, y in db_test:
#             out = tf.reshape(x, [-1, 512])
            y_pred = model(x)
            y_prob = tf.nn.softmax(y_pred, axis=1)
            y_pred = tf.argmax(y_prob, axis=1, output_type=tf.int32)
            y_true = tf.squeeze(y, axis=1)
            result = tf.equal(y_pred, y_true)
            correct = tf.reduce_sum(tf.cast(result, dtype=tf.int32))
            total_correct += correct
            total_num += x.shape[0]
        acc = total_correct / total_num
        print(epoch, 'acc:', acc)
            
            

0 0 loss: 4.603629112243652
0 acc: tf.Tensor(0.0, shape=(), dtype=float64)
0 acc: tf.Tensor(0.0, shape=(), dtype=float64)


KeyboardInterrupt: 