# 量化感知训练和训练后量化的对比

In [1]:
# 本文件主要是以 MNIST数据集 为例 对比 量化感知训练 和 训练后量化 两种量化方式
# 主要分为四部分
# 1 训练一个浮点数模型作为基础标准
# 2 进行量化感知训练
# 3 进行训练后量化
# 4 比较两者结果
# 作者：wzx
# 修改时间：2020.12.14

## 训练浮点数模型

In [2]:
import tempfile
import os
import tensorflow as tf
from tensorflow import keras

import numpy as np

In [3]:
# Path settings: one output directory holding the saved Keras weights and the
# two converted TFLite models.
basic_path = './model_saved'
# exist_ok=True replaces the separate exists()/mkdir() pair, so there is no
# window between the check and the creation, and the directory name is spelled
# only once.
os.makedirs(basic_path, exist_ok=True)
model_path = os.path.join(basic_path, 'MNIST_model.h5')
QAT_INT8_model_path = os.path.join(basic_path, 'QAT_INT8_MNIST.tflite')
PTQ_INT8_model_path = os.path.join(basic_path, 'PTQ_INT8_MNIST.tflite')

In [4]:
# Basic settings shared by the training cells.
NUM_classes = 10   # number of output classes (MNIST digits 0-9)
BATCH_SIZE = 128   # mini-batch size for the baseline training run
EPOCHS = 5         # epochs for the baseline training run

# Fraction of the training data held out for validation by the fit() calls.
VALIDATION_SPLIT = 0.1
# `lr` is the name the later cells pass as `validation_split`; despite the
# name it is never used as a learning rate. Kept as an alias so those cells
# keep working unchanged.
lr = VALIDATION_SPLIT

# Input image dimensions.
img_rows, img_cols = 28, 28

In [5]:
# Fetch the MNIST train/test splits through the Keras dataset helper.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale both image sets by 1/255 so every pixel lies in [0, 1].
train_images, test_images = train_images / 255.0, test_images / 255.0

In [6]:
# Define the model architecture: reshape the 2-D image to add a channel axis,
# one 3x3 convolution, 2x2 max-pooling, then a dense layer that emits one raw
# logit per class (no softmax; the loss is configured with from_logits=True).
# The image size and class count come from the settings cell rather than being
# repeated as literals.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(img_rows, img_cols)),
    keras.layers.Reshape(target_shape=(img_rows, img_cols, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(NUM_classes)
])

# Print the layer-by-layer summary.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 12)        0         
_________________________________________________________________
flatten (Flatten)            (None, 2028)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                20290     
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
_________________________________________________________________


In [7]:
# Train the float baseline digit classifier.
# The last Dense layer has no activation, so the loss is told to expect logits.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

# NOTE: `lr` is passed as the validation fraction here, not as a learning
# rate; the optimizer is selected only by the string 'adam'.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=lr,
)

# Persist the trained weights to disk.
model.save_weights(model_path)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## 进行量化感知训练

In [8]:
import tensorflow_model_optimization as tfmot

In [9]:
# Insert fake-quantization nodes into the trained float model.
quantize_model = tfmot.quantization.keras.quantize_model
# QAT stands for quantization-aware training.
QAT_model = quantize_model(model)
QAT_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantize_layer (QuantizeLaye (None, 28, 28)            3         
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1         
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295     
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
___________________________________________________

In [10]:
# Fine-tune the quantization-aware model on a small slice of the training set
# (the first 1000 of the 60000 samples).
train_images_subset = train_images[:1000]
train_labels_subset = train_labels[:1000]

# The model returned by quantize_model has to be compiled again before use.
QAT_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

# NOTE: `lr` is passed as the validation fraction here, not as a learning rate.
QAT_model.fit(
    x=train_images_subset,
    y=train_labels_subset,
    epochs=1,
    batch_size=500,
    validation_split=lr,
)



<tensorflow.python.keras.callbacks.History at 0x7f9a8063fa58>

In [11]:
# Evaluate the float baseline and the quantization-aware (not yet converted)
# Keras model on the test set. Both models were compiled with a single
# 'accuracy' metric, so evaluate() yields (loss, accuracy); the loss is
# discarded.
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)
_, QAT_model_accuracy = QAT_model.evaluate(test_images, test_labels, verbose=0)

print('正常训练的准确率')
print('Baseline test accuracy:', baseline_model_accuracy)
print('量化感知训练(未量化)的准确率')
print('Quant test accuracy:', QAT_model_accuracy)

正常训练的准确率
Baseline test accuracy: 0.9732999801635742
量化感知训练(未量化)的准确率
Quant test accuracy: 0.9746000170707703


In [12]:
# Data preprocessing: give every image an explicit channel axis, placed
# according to the backend's image data format.
input_shape = (
    (1, img_rows, img_cols)
    if tf.keras.backend.image_data_format() == 'channels_first'
    else (img_rows, img_cols, 1)
)
x_train = train_images.reshape((train_images.shape[0],) + input_shape).astype('float32')
x_test = test_images.reshape((test_images.shape[0],) + input_shape).astype('float32')

# Shape of a batch that holds exactly one sample.
batch_input_shape = (1, ) + input_shape

# No division by 255 here: train_images / test_images were already scaled
# when the dataset was loaded.

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [13]:
# Calibration data handed to the TFLite converter so it can determine the
# quantization thresholds.
def representative_data_gen(num_samples=300):
    """Yield single-image calibration batches taken from the start of `x_train`.

    Args:
        num_samples: Upper bound on the number of images to yield. If
            `x_train` holds fewer images, all of them are yielded instead of
            raising an IndexError.

    Yields:
        A one-element list containing a float32 array reshaped to
        `batch_input_shape`.
    """
    # Slicing (rather than indexing range(num_samples)) makes a short dataset
    # end the generator cleanly.
    for image in x_train[:num_samples]:
        yield [image.reshape(batch_input_shape).astype(np.float32)]

In [14]:
# Use the TF Lite converter to turn the quantization-aware Keras model into an
# actually quantized TFLite model.
converter = tf.lite.TFLiteConverter.from_keras_model(QAT_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Sample inputs for calibration come from representative_data_gen.
converter.representative_dataset = representative_data_gen
converter._experimental_new_quantizer = True  # pylint: disable=protected-access

# Restrict conversion to the int8 builtin op set so that any op that cannot
# be quantized makes the converter raise an error.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Request int8 model input and output tensors (APIs added in r2.3).
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8 


print('Convert TFLite model.')
QAT_INT8_model = converter.convert()

Convert TFLite model.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/tmpi00249dc/assets


In [15]:
# Save the quantized TFLite model. The context manager guarantees the file
# handle is flushed and closed even if the write fails.
with open(QAT_INT8_model_path, "wb") as tflite_file:
    tflite_file.write(QAT_INT8_model)
# Echo the size of the written file in bytes as the cell output.
os.path.getsize(QAT_INT8_model_path)

24656

In [16]:
# Helper that measures the accuracy of a TF Lite model on a labelled image set.
def evaluate_model(interpreter, images=None, labels=None):
    """Run a TF Lite interpreter over labelled images and return the accuracy.

    Args:
        interpreter: An interpreter whose tensors are already allocated. It
            must provide `get_input_details`, `get_output_details`,
            `set_tensor`, `invoke` and `get_tensor`.
        images: Iterable of single images without a batch axis. Defaults to
            the notebook-level `test_images`.
        labels: Ground-truth class ids aligned with `images`. Defaults to the
            notebook-level `test_labels`.

    Returns:
        The fraction of images whose arg-max output equals the label.
    """
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    input_details = interpreter.get_input_details()[0]
    input_index = input_details["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # A model converted with an integer inference_input_type expects inputs
    # that are already quantized; feeding float32 to such a tensor is a type
    # mismatch. Use the dtype and (scale, zero_point) the interpreter reports.
    input_dtype = np.dtype(input_details["dtype"])
    input_scale, input_zero_point = input_details.get("quantization", (0.0, 0))
    quantize_input = bool(np.issubdtype(input_dtype, np.integer) and input_scale != 0)
    if quantize_input:
        dtype_limits = np.iinfo(input_dtype)

    # Run predictions on every image.
    prediction_digits = []
    for i, test_image in enumerate(images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))

        # Pre-processing: map real values onto the integer grid when needed
        # (rounded and clipped to the dtype range), then add the batch axis
        # and cast to the dtype the model input expects.
        if quantize_input:
            test_image = np.round(test_image / input_scale + input_zero_point)
            test_image = np.clip(test_image, dtype_limits.min, dtype_limits.max)
        test_image = np.expand_dims(test_image, axis=0).astype(input_dtype)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: drop the batch axis and take the highest-scoring
        # class. With a positive output scale the arg-max of the raw
        # (possibly quantized) output equals that of the dequantized one, so
        # no dequantization is applied.
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))

    print('\n')
    # Compare predictions with the ground-truth labels to compute accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == np.asarray(labels)).mean()
    return accuracy

In [17]:
# Load the converted quantization-aware model from memory into a TF Lite
# interpreter, allocate its tensors, and measure its test-set accuracy with
# the evaluate_model helper.
interpreter = tf.lite.Interpreter(model_content=QAT_INT8_model)
interpreter.allocate_tensors()

QAT_test_accuracy = evaluate_model(interpreter)
print('INT8量化后的量化感知训练模型准确度：', QAT_test_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


INT8量化后的量化感知训练模型准确度： 0.9747


## 训练后量化

In [18]:
# Post-training quantization: convert the float baseline Keras model directly,
# using the same converter settings as for the quantization-aware model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Sample inputs for calibration come from representative_data_gen.
converter.representative_dataset = representative_data_gen
converter._experimental_new_quantizer = True  # pylint: disable=protected-access

# Restrict conversion to the int8 builtin op set so that any op that cannot
# be quantized makes the converter raise an error.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Request int8 model input and output tensors (APIs added in r2.3).
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8 


print('Convert TFLite model.')
PTQ_INT8_model = converter.convert()

Convert TFLite model.
INFO:tensorflow:Assets written to: /tmp/tmpav2w__kj/assets


INFO:tensorflow:Assets written to: /tmp/tmpav2w__kj/assets


In [19]:
# Save the quantized TFLite model. The context manager guarantees the file
# handle is flushed and closed even if the write fails.
with open(PTQ_INT8_model_path, "wb") as tflite_file:
    tflite_file.write(PTQ_INT8_model)
# Echo the size of the written file in bytes as the cell output.
os.path.getsize(PTQ_INT8_model_path)

24168

In [20]:
# Helper that measures the accuracy of a TF Lite model on a labelled image
# set; in this section it is applied to the post-training-quantized model.
# NOTE: an earlier cell also defines `evaluate_model`; running this cell
# rebinds the name to the definition below.
def evaluate_model(interpreter, images=None, labels=None):
    """Run a TF Lite interpreter over labelled images and return the accuracy.

    Args:
        interpreter: An interpreter whose tensors are already allocated. It
            must provide `get_input_details`, `get_output_details`,
            `set_tensor`, `invoke` and `get_tensor`.
        images: Iterable of single images without a batch axis. Defaults to
            the notebook-level `test_images`.
        labels: Ground-truth class ids aligned with `images`. Defaults to the
            notebook-level `test_labels`.

    Returns:
        The fraction of images whose arg-max output equals the label.
    """
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    input_details = interpreter.get_input_details()[0]
    input_index = input_details["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # A model converted with an integer inference_input_type expects inputs
    # that are already quantized; feeding float32 to such a tensor is a type
    # mismatch. Use the dtype and (scale, zero_point) the interpreter reports.
    input_dtype = np.dtype(input_details["dtype"])
    input_scale, input_zero_point = input_details.get("quantization", (0.0, 0))
    quantize_input = bool(np.issubdtype(input_dtype, np.integer) and input_scale != 0)
    if quantize_input:
        dtype_limits = np.iinfo(input_dtype)

    # Run predictions on every image.
    prediction_digits = []
    for i, test_image in enumerate(images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))

        # Pre-processing: map real values onto the integer grid when needed
        # (rounded and clipped to the dtype range), then add the batch axis
        # and cast to the dtype the model input expects.
        if quantize_input:
            test_image = np.round(test_image / input_scale + input_zero_point)
            test_image = np.clip(test_image, dtype_limits.min, dtype_limits.max)
        test_image = np.expand_dims(test_image, axis=0).astype(input_dtype)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: drop the batch axis and take the highest-scoring
        # class. With a positive output scale the arg-max of the raw
        # (possibly quantized) output equals that of the dequantized one, so
        # no dequantization is applied.
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))

    print('\n')
    # Compare predictions with the ground-truth labels to compute accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == np.asarray(labels)).mean()
    return accuracy

In [21]:
# Load the post-training-quantized model from memory into a TF Lite
# interpreter, allocate its tensors, and measure its test-set accuracy with
# the evaluate_model helper.
interpreter = tf.lite.Interpreter(model_content=PTQ_INT8_model)
interpreter.allocate_tensors()

PTQ_test_accuracy = evaluate_model(interpreter)
print('INT8量化后的训练后量化模型准确度：', PTQ_test_accuracy)


Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


INT8量化后的训练后量化模型准确度： 0.9733


In [22]:
# Final comparison: float baseline accuracy (from Keras evaluate) next to the
# accuracies of the two INT8 TFLite models (from evaluate_model).
print('原模型准确率：', baseline_model_accuracy)
print('INT8训练后量化 模型准确度：', PTQ_test_accuracy)
print('INT8量化感知训练 模型准确度：', QAT_test_accuracy)

原模型准确率： 0.9732999801635742
INT8训练后量化 模型准确度： 0.9733
INT8量化感知训练 模型准确度： 0.9747
