In [4]:
import tensorflow as tf
import numpy as np
import os
import tempfile
from tensorflow import keras
from keras import Sequential
from keras.layers import InputLayer, Reshape, Conv2D, MaxPooling2D, Flatten, Dense
import pathlib

In [5]:
"""This is done so to show effects of Quantization Aware training"""
"""Start of non-QAT"""
# Load mnist dataset
mnist = keras.datasets.mnist
(train_img, train_label),(test_img, test_label) = mnist.load_data()

In [6]:
# Normalize data/img
train_img = train_img/255.0
test_img = test_img/255.0

# Setup model architecture
model = Sequential([
    InputLayer(input_shape=(28,28)),
    Reshape(target_shape=(28,28,1)),
    Conv2D(filters=12, kernel_size=(3,3),activation='relu'),
    MaxPooling2D(pool_size=(2,2)),
    Flatten(),
    Dense(10)
    ])

In [7]:
# Train model
OPT = 'adam'
LOSS = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
METRICS = ['accuracy']
model.compile(optimizer=OPT, loss=LOSS, metrics=METRICS)

model.fit(train_img, train_label, epochs=1, validation_split=0.1)

"""End of non-QAT"""



'End of non-QAT'

In [8]:
"""QAT optimization"""
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model
qat_model = quantize_model(model)

# recompile qat model
qat_model.compile(optimizer=OPT,loss=LOSS,metrics=METRICS)

qat_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantize_layer (QuantizeLaye (None, 28, 28)            3         
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1         
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295     
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
___________________________________________________

In [9]:
# train and fit against baseline
train_img_subset = train_img[0:1000]
train_label_subset = train_label[0:1000]

qat_model.fit(train_img_subset, train_label_subset, batch_size=500, epochs=1, validation_split=0.1)

_,baseline_model_acc = model.evaluate(test_img, test_label, verbose=0)
_,qat_model_acc = qat_model.evaluate(test_img, test_label, verbose=0)

print('baseline accuracy is: {}'.format(baseline_model_acc))
print('QAT model accuracy is: {}'.format(qat_model_acc))

# Note: qat_model is not yet quantized!

baseline accuracy is: 0.9621999859809875
QAT model accuracy is: 0.9616000056266785


In [10]:
# Quantize and convert to TFLite
tfl_converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
tfl_converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_quantized_model = tfl_converter.convert()
tflite_q_model = pathlib.Path('tflite_quant_model.tflite')
tflite_q_model.write_bytes(tflite_quantized_model)



INFO:tensorflow:Assets written to: C:\Users\User\AppData\Local\Temp\tmppb0ujr3k\assets


INFO:tensorflow:Assets written to: C:\Users\User\AppData\Local\Temp\tmppb0ujr3k\assets


24592

In [11]:
# Test TFLite accuracy | Evaluate and compare
# Define fn to evaluate models
def evaluate_model(interpreter):
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    # predict every img in the test dataset
    prediction = []
    for i, test_im in enumerate(test_img):
        if i %1000 == 0:
            print('evaluated on {} reslts so far'.format(i))
        
        # convert to float32 to match model's input data format
        test_im = np.expand_dims(test_im, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_im)

        # Inference
        interpreter.invoke()

        # fin digit with max probability
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction.append(digit)

    print('\n')

    # compare results with actual labels and compute acc
    prediction = np.array(prediction)
    accuracy = (prediction == test_label).mean()
    return accuracy

In [12]:
interpreter = tf.lite.Interpreter(model_content=tflite_quantized_model)
interpreter.allocate_tensors()

test_acc = evaluate_model(interpreter)

print('QAT TFLite acc: {}'.format(test_acc))
print('QAT TF acc: {}'.format(qat_model_acc))


evaluated on 0 reslts so far
evaluated on 1000 reslts so far
evaluated on 2000 reslts so far
evaluated on 3000 reslts so far
evaluated on 4000 reslts so far
evaluated on 5000 reslts so far
evaluated on 6000 reslts so far
evaluated on 7000 reslts so far
evaluated on 8000 reslts so far
evaluated on 9000 reslts so far


QAT TFLite acc: 0.9617
QAT TF acc: 0.9616000056266785


In [16]:
# Compress model by 4x
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_float_model = float_converter.convert()

# find model size
_,float_file = tempfile.mkstemp('.tflite')
_,quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
    f.write(tflite_quantized_model)
with open(float_file, 'wb') as f:
    f.write(tflite_float_model)

print("float_model size: {} MB".format(os.path.getsize(float_file)/float(2**20)))
print("quantized_model size: {} MB".format(os.path.getsize(quant_file)/float(2**20)))

INFO:tensorflow:Assets written to: C:\Users\User\AppData\Local\Temp\tmpwgq5mntc\assets


INFO:tensorflow:Assets written to: C:\Users\User\AppData\Local\Temp\tmpwgq5mntc\assets


float_model size: 0.08058547973632812 MB
quantized_model size: 0.0234527587890625 MB


In [14]:
# Save model(s)
# using savedModel
model.save('saved_model/baseline_q_model')
qat_model.save('saved_model/tf_qat_model')
# using hdf5
model.save('baseline_model.h5')
qat_model.save('tf_q_model.h5')

INFO:tensorflow:Assets written to: saved_model/baseline_q_model\assets


INFO:tensorflow:Assets written to: saved_model/baseline_q_model\assets


INFO:tensorflow:Assets written to: saved_model/tf_qat_model\assets


INFO:tensorflow:Assets written to: saved_model/tf_qat_model\assets


In [15]:
# Save tflite model(s)
tflite_model_file = pathlib.Path('tflite_float_model.tflite')
tflite_model_file.write_bytes(tflite_float_model)

84500