post-training_quantization_inception_v3.py
"""
Post Training Quantization using TensorFlow for VGG-16
======================================================
# The script will generate a full integer quantized
VGG-16 .tflite model
# By Kuen-Wey Lin
# pip3 show tensorflow => Version: 2.5.0
"""
import tensorflow as tf
from tensorflow import keras
import numpy as np
from PIL import Image
image_shape = (299, 299)
data_type = 'float32' # the already-trained float TensorFlow model expects FLOAT32 input
# a generator function that yields preprocessed images for calibration
def representative_dataset():
    num_calibration_steps = 50 # number of images used for calibration
    batch_size = 1
    imgs = []
    for sn in range(num_calibration_steps):
        image_data = tf.keras.preprocessing.image.load_img('./test/' + str(sn) + '.jpg', target_size=image_shape)
        image_data = tf.keras.preprocessing.image.img_to_array(image_data)
        image_data = tf.keras.applications.inception_v3.preprocess_input(image_data)
        #image_data = image_data.astype(np.float32)
        image_data = np.reshape(image_data, (299, 299, 3))
        imgs.append(image_data)
    imgs = np.array(imgs)
    images = tf.data.Dataset.from_tensor_slices(imgs).batch(batch_size)
    # yield every calibration batch, one image at a time
    for i in images.take(num_calibration_steps):
        yield [i]
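# Optional sanity check (a minimal sketch, not part of the original script):
# the converter expects the representative dataset to yield lists of float32
# tensors shaped like the model input, i.e. (1, 299, 299, 3) here.
#for sample in representative_dataset():
#    print(sample[0].shape, sample[0].dtype) # expected: (1, 299, 299, 3) <dtype: 'float32'>
#    break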
# a Keras image classification model, loaded with weights pre-trained on ImageNet
# You can find the downloaded Keras files in $HOME/.keras
model = tf.keras.applications.inception_v3.InceptionV3(weights="imagenet", input_shape=(299, 299, 3))
# convert a tf.Keras model to a TensorFlow Lite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# "DEFAULT" quantizes model weights
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# tf.lite.RepresentativeDataset requires a generator function, so use Python's "yield"
converter.representative_dataset = representative_dataset
# The following three lines are used for full integer quantization (input/output/activation tensors are int8)
# To generate a model with float32 input/output instead, comment out the following three lines
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8 # or tf.uint8
converter.inference_output_type = tf.int8 # or tf.uint8
tflite_quant_model = converter.convert()
with open('model.tflite', 'wb') as f:
    f.write(tflite_quant_model)
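# The lines below are not part of the original script; they are a minimal sketch of
# how the generated full-integer model could be run with tf.lite.Interpreter.
# They assume a test image './test/0.jpg' exists (the same directory the calibration
# images are read from above).
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
# preprocess one image exactly as during calibration
img = tf.keras.preprocessing.image.load_img('./test/0.jpg', target_size=image_shape)
img = tf.keras.applications.inception_v3.preprocess_input(tf.keras.preprocessing.image.img_to_array(img))
# the model input is int8, so quantize the float image with the input's scale and zero point
scale, zero_point = input_details['quantization']
img_int8 = np.clip(np.round(img / scale + zero_point), -128, 127).astype(np.int8)
interpreter.set_tensor(input_details['index'], np.expand_dims(img_int8, axis=0))
interpreter.invoke()
predictions = interpreter.get_tensor(output_details['index'])
print(np.argmax(predictions)) # index of the predicted ImageNet class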