# 导入函数库

In [1]:
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras  # 在TF2.0版本中,已经是自带Keras了,所以不需要额外安装
import tensorflow_model_optimization as tfmot  # 导入TF2.0的模型优化函数库,降低模型优化难度,相当于调用借口解决问题
import zipfile
import tempfile
import tensorflow_datasets as tfds # 这个是之前说过的Tensorflow Datasets
%load_ext tensorboard

# 如果出现显存不够的错误,把这个代码加上

可以展示下不加这个出现错误的情形

In [2]:
# 加入下面这个代码片段主要是因为TF2.0对RTX20系列显卡支持不是很好,容易爆显存,所以设置成用多少占多少显存,而不是一次性全占了(默认)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# 定义网络结构

## 一些参数设置

In [3]:
layers = tf.keras.layers
models = tf.keras.models
keras_utils = tf.keras.utils

## 定义LeNet5用于猫狗分类

In [4]:
def LeNet5(include_top=False,
         weights='imagenet',
         input_shape=(32, 32, 3),
         pooling='avg',
         classes=2,
         **kwargs):
    img_input = layers.Input(shape=input_shape)  # 输入节点
    
    # Block 1
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block1_conv1')(img_input)
    x = layers.AveragePooling2D()(x)
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block1_conv2')(x)
    x = layers.AveragePooling2D()(x)
    
    if pooling == 'avg':
        x = layers.Flatten()(x)
        x = layers.Dense(1024, activation='relu', name='fc1')(x)
        x = layers.Dense(128, activation='relu', name='fc2')(x)
        x = layers.Dense(classes, activation='softmax', name='predictions')(x) # 此处是10分类，而不是ImageNet的1000分类
    
    inputs = img_input  # inputs是输入节点, x是输出节点
    model = models.Model(inputs, x , name='vgg16')  # 生成一个Model, 需要指定输入和输出
    
    return model

## 声明一个LeNet5模型实例

In [5]:
model = LeNet5()
# model.load_weights('./baseline.h5')

## 这个函数可以用于查看网络结构和参数量

In [6]:
model.summary()

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 32, 32, 64)        1792      
_________________________________________________________________
average_pooling2d (AveragePo (None, 16, 16, 64)        0         
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 16, 16, 64)        36928     
_________________________________________________________________
average_pooling2d_1 (Average (None, 8, 8, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 4096)              0         
_________________________________________________________________
fc1 (Dense)                  (None, 1024)              419532

In [7]:
log_dir = './logs'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

## 对特征提取层进行冻结,加快模型训练速度

可以展示下不加的情形，更直观的说明为什么需要训练更长时间

In [8]:
length = len(model.layers)  # 查看模型有多少层Layer
print(length)

9


In [9]:
# 对模型的特征提取层进行冻结,可以加快模型训练速度
# for i in range(length - 2): # 我们看下上面的model.summary()，可以发现最后2层是全连接，所以不冻结
#     model.layers[i].trainable = False # 对该层Layer进行冻结

In [10]:
for i in range(length):
    print(model.layers[i].trainable)

True
True
True
True
True
True
True
True
True


## 进行模型训练

In [11]:
def convert(image, label):
    image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]
    image = tf.image.resize_with_crop_or_pad(image, 34, 34) # Add 6 pixels of padding
    image = tf.image.random_crop(image, size=[32, 32, 3]) # Random crop back to 32x32
    return image, label

def augment(image,label):
    image,label = convert(image, label)
#     image = tf.image.random_flip_left_right(image)
#     image = tf.image.random_flip_up_down(image)
#     image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness
    return image,label

In [12]:
# 此处使用tfds的方式构建data pipeline
raw_train, metadata = tfds.load(
    'cats_vs_dogs', # 数据集名称，这个是手势识别分类数据集，共3个类别
    split='train', # 这里的raw_test和split的'test'对应，raw_train和split的'train'对应
    with_info=True, # 这个参数和metadata对应
    as_supervised=True, # 这个参数的作用是返回tuple形式的(input, label),举个例子，raw_test=tuple(input, label)
    shuffle_files=True  # 对数据进行扰乱操作，可以自己体会下设置成False时，下面imshow的时候的结果差别
)
raw_test = raw_train

IMG_SIZE = 32 # All images will be resized to 300X300

BATCH_SIZE = 16
SHUFFLE_BUFFER_SIZE = 5000

# 可以体验下这里是否加prefetch(tf.data.experimental.AUTOTUNE)和cache()的区别，对训练速度，以及CPU负载有影响
train_batches = raw_train.shuffle(SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True).map(augment).batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
test_batches = raw_test.map(convert).batch(BATCH_SIZE)

# compile模型并训练
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
             metrics=['accuracy'])

In [13]:
model.fit(
    train_batches,
    epochs=20,
    callbacks=[tensorboard_callback]
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f6e3027f400>

## Baseline的test acc和推断速度

In [14]:
# 此处我们可以看到Baseline的test acc和inference速度为6ms/step
_, baseline_model_accuracy = model.evaluate(test_batches, verbose=1)
print('Baseline test accuracy: ', baseline_model_accuracy)

keras_file = './test.h5'
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to: ', keras_file)

Baseline test accuracy:  0.749720573425293
Saved baseline model to:  ./test.h5


In [15]:
%tensorboard --logdir ./logs

Reusing TensorBoard on port 6011 (pid 14899), started 0:09:31 ago. (Use '!kill 14899' to kill it.)