In [1]:
# Google Colab setup: mount Google Drive so the later cells can read and write
# files under /content/drive/My Drive/.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point, force_remount=True)  # re-mount even if already mounted
root_dir = '/content/drive/My Drive/'

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
# imports and load the MNIST data

import numpy as np
import tensorflow as tf

def _to_model_input(images):
    """Scale 0-255 pixel values to float32 in [0, 1] and reshape to (N, 28, 28, 1)."""
    scaled = images.astype('float32') / 255.0
    return scaled.reshape(-1, 28, 28, 1)  # trailing channel axis for Conv2D

(train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data()

train_data = _to_model_input(train_data)
eval_data = _to_model_input(eval_data)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [0]:
# build tf.keras model

def build_keras_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the (uncompiled) CNN classifier used for quantization-aware training.

    Architecture:
      * three 3x3 'same' Conv2D layers (32, 64, 64 filters) with ReLU, each
        followed by a non-fused BatchNormalization;
      * 2x2 / stride-2 max-pooling after the first two conv blocks;
      * Flatten -> Dense(64, ReLU) -> Dense(num_classes, softmax).

    This function is called once for the training graph and once for the eval
    graph. Both calls must create the same layers in the same order so that the
    default layer/variable names match and the training checkpoint can be
    restored into the eval graph.

    NOTE(review): ReLU is applied inside Conv2D, *before* BatchNormalization,
    so BN cannot be folded into the conv weights. The converted TFLite model
    keeps separate batchnorm mul/sub/add tensors. Reordering to
    Conv -> BN -> ReLU would change the architecture and invalidate existing
    checkpoints, so it is left as is.

    Args:
        input_shape: Shape of a single input image, channels last.
            Defaults to MNIST's (28, 28, 1).
        num_classes: Number of output classes. Defaults to 10 (MNIST digits).

    Returns:
        A tf.keras.models.Sequential model.
    """
    return tf.keras.models.Sequential([

        tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu, padding='same', input_shape=input_shape),
        # Non-fused BN: presumably chosen so tf.contrib.quantize sees the
        # individual batchnorm mul/add ops (the fake-quant log lists them).
        tf.keras.layers.BatchNormalization(fused=False),

        tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2),

        tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu, padding='same'),
        tf.keras.layers.BatchNormalization(fused=False),

        tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2),

        tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu, padding='same'),
        tf.keras.layers.BatchNormalization(fused=False),

        tf.keras.layers.Flatten(),

        tf.keras.layers.Dense(64, activation=tf.nn.relu),

        tf.keras.layers.Dense(num_classes, activation=tf.nn.softmax)
    ])


In [4]:
# Train with quantization-aware training (QAT): the first `quant_delay_epoch`
# epochs train in float, then the fake-quant ops inserted by
# tf.contrib.quantize start quantizing (the fine-tuning phase).

train_batch_size = 50
# NOTE: despite its name this is the number of training *samples*, not batches.
# The name is kept in case other cells reference it.
train_batch_number = train_data.shape[0]
quant_delay_epoch = 1

# Keras runs ceil(samples / batch_size) steps per epoch (the last batch may be
# partial), so use ceiling division. Plain int(samples / batch_size * epochs)
# would start quantizing too early whenever the sample count is not a multiple
# of the batch size.
steps_per_epoch = -(-train_batch_number // train_batch_size)  # ceiling division
quant_delay_steps = steps_per_epoch * quant_delay_epoch

checkpoint_prefix = '/content/drive/My Drive/Colab Notebooks/quantization_github/quantization_aware_training_model/checkpoints'

train_graph = tf.Graph()
train_sess = tf.Session(graph=train_graph)

tf.keras.backend.set_session(train_sess)
with train_graph.as_default():
    train_model = build_keras_model()

    # Rewrite the forward graph with fake-quant ops before compile()/fit()
    # build the optimizer, so training runs through the rewritten graph.
    tf.contrib.quantize.create_training_graph(input_graph=train_graph, quant_delay=quant_delay_steps)

    # The rewrite adds variables (e.g. fake-quant min/max ranges) that the
    # Keras model does not own; initialize every global variable explicitly.
    train_sess.run(tf.global_variables_initializer())

    train_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    print('\n------ Train ------\n')
    # Float epochs followed by the same number of quantized fine-tuning epochs.
    train_model.fit(train_data, train_labels, batch_size=train_batch_size, epochs=quant_delay_epoch * 2)

    print('\n------ Test ------\n')
    loss, acc = train_model.evaluate(eval_data, eval_labels)
    print('eval loss: {:.4f}, eval accuracy: {:.4f}'.format(loss, acc))

    saver = tf.train.Saver()
    saver.save(train_sess, checkpoint_prefix)


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization/batchnorm/mul_1
INFO:tensorflow:Inserting fake quant op activation_AddV2_quant after batch_normalization/batchnorm/add_1
INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization_1/batchnorm/mul_1
INFO:tensorflow:Inserting fake quant op activation_AddV2_quant after batch_normalization_1/batchnorm/add_1
INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization_2/batchnorm/mul_1
INFO:tensorflow:Inserting fake

In [5]:
# Rebuild the model in a fresh inference graph, apply the quantization eval
# rewrite, restore the trained weights from the checkpoint, and freeze
# variables into constants so the TFLite converter can read a single GraphDef.

eval_graph = tf.Graph()
eval_sess = tf.Session(graph=eval_graph)

tf.keras.backend.set_session(eval_sess)

with eval_graph.as_default():
    tf.keras.backend.set_learning_phase(0)  # build the model in inference mode
    eval_model = build_keras_model()
    tf.contrib.quantize.create_eval_graph(input_graph=eval_graph)

    # Captured before the Saver is created, so the GraphDef holds only the
    # model and its quantization ops.
    eval_graph_def = eval_graph.as_graph_def()

    saver = tf.train.Saver()
    checkpoint_prefix = '/content/drive/My Drive/Colab Notebooks/quantization_github/quantization_aware_training_model/checkpoints'
    saver.restore(eval_sess, checkpoint_prefix)

    output_node_names = [eval_model.output.op.name]
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(eval_sess, eval_graph_def, output_node_names)

    frozen_graph_path = '/content/drive/My Drive/Colab Notebooks/quantization_github/quantization_aware_training_model/frozen_graph.pb'
    with open(frozen_graph_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization/batchnorm/mul_1
INFO:tensorflow:Inserting fake quant op activation_AddV2_quant after batch_normalization/batchnorm/add_1
INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization_1/batchnorm/mul_1
INFO:tensorflow:Inserting fake quant op activation_AddV2_quant after batch_normalization_1/batchnorm/add_1
INFO:tensorflow:Inserting fake quant op activation_Mul_quant after batch_normalization_2/batchnorm/mul_1
INFO:tensorflow:Inserting fake quant op activation_AddV2_quant after batch_normalization_2/batchnorm/add_1
INFO:tensorflow:Restoring parameters from /content/drive/My Drive/Colab Notebooks/quantization_github/quantization_aware_training_model/checkpoints
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 54 variables.
INFO:tensorflow:Conve

In [6]:
# Convert the frozen QAT graph to a fully quantized (uint8) TFLite model.

# quantized_input_stats = (mean, std_dev) maps a uint8 input q to the float the
# graph expects via real = (q - mean) / std_dev. These values make q = 0..255
# span the observed training-data range [input_min, input_max].
input_max = np.max(train_data)
input_min = np.min(train_data)
if input_max == input_min:
    raise ValueError('train_data has a constant value; cannot derive quantized input stats')
converter_std = 255 / (input_max - input_min)
converter_mean = -(input_min * converter_std)

converter = tf.lite.TFLiteConverter.from_frozen_graph('/content/drive/My Drive/Colab Notebooks/quantization_github/quantization_aware_training_model/frozen_graph.pb',
                                                     ['conv2d_input'],     # input placeholder Keras created for the first Conv2D
                                                     ['dense_1/Softmax'])  # softmax output op of the final Dense layer
converter.inference_type = tf.uint8
converter.quantized_input_stats = {'conv2d_input': (converter_mean, converter_std)}
# No default_ranges_stats fallback: activation ranges are expected to come from
# the fake-quant ops recorded during QAT.
tflite_model = converter.convert()

with open('/content/drive/My Drive/Colab Notebooks/quantization_github/quantization_aware_training_model/quantized_model.tflite', 'wb') as f:
    bytes_written = f.write(tflite_model)
bytes_written

264048

In [8]:
# Load the quantized TFLite model and measure its top-1 accuracy on the MNIST
# eval set, feeding one uint8 image at a time.

interpreter = tf.lite.Interpreter(model_path='/content/drive/My Drive/Colab Notebooks/quantization_github/quantization_aware_training_model/quantized_model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The converted model takes uint8 input. Map the float images back with the
# model's own quantization parameters: real = (q - zero_point) * scale.
assert input_details[0]['dtype'] == np.uint8, input_details[0]['dtype']
input_scale, input_zero_point = input_details[0]['quantization']
assert input_scale > 0, 'model input is not quantized'
# Round before casting: a bare uint8 cast truncates, and the float32 round trip
# k / 255 * 255 can land just below k (e.g. 254.99998 -> 254).
quantize_eval_data = np.clip(
    np.rint(eval_data / input_scale + input_zero_point), 0, 255
).astype(np.uint8)

acc = 0  # number of correctly classified eval images
for image, label in zip(quantize_eval_data, eval_labels):
    interpreter.set_tensor(input_details[0]['index'], image.reshape(1, 28, 28, 1))  # batch of one
    interpreter.invoke()
    prediction = interpreter.get_tensor(output_details[0]['index'])

    if label == np.argmax(prediction):
        acc += 1

print('Quantization-aware training (finetuning) accuracy: ' + str(acc / len(eval_data)))

Quantization-aware training (finetuning) accuracy: 0.9886


In [12]:
# List dtype, name and index of every tensor in the TFLite model, to see which
# tensors were quantized to uint8 and which were not (e.g. the int32 biases).
tensor_details = interpreter.get_tensor_details()

for detail in tensor_details:
    print(detail['dtype'], detail['name'], detail['index'])

<class 'numpy.uint8'> batch_normalization/batchnorm/add_1 0
<class 'numpy.uint8'> batch_normalization/batchnorm/mul 1
<class 'numpy.uint8'> batch_normalization/batchnorm/mul_1 2
<class 'numpy.uint8'> batch_normalization/batchnorm/sub 3
<class 'numpy.uint8'> batch_normalization_1/batchnorm/add_1 4
<class 'numpy.uint8'> batch_normalization_1/batchnorm/mul 5
<class 'numpy.uint8'> batch_normalization_1/batchnorm/mul_1 6
<class 'numpy.uint8'> batch_normalization_1/batchnorm/sub 7
<class 'numpy.uint8'> batch_normalization_2/batchnorm/add_1 8
<class 'numpy.uint8'> batch_normalization_2/batchnorm/mul 9
<class 'numpy.uint8'> batch_normalization_2/batchnorm/mul_1 10
<class 'numpy.uint8'> batch_normalization_2/batchnorm/sub 11
<class 'numpy.int32'> conv2d/Conv2D_bias 12
<class 'numpy.uint8'> conv2d/Relu 13
<class 'numpy.uint8'> conv2d/weights_quant/FakeQuantWithMinMaxVars 14
<class 'numpy.int32'> conv2d_1/Conv2D_bias 15
<class 'numpy.uint8'> conv2d_1/Relu 16
<class 'numpy.uint8'> conv2d_1/weights