# In [None]:
import os
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_model_optimization as tfmot

import argparse
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.lib.io import file_io
import json
import zipfile
import pathlib


################### args ##########################################
# Command-line options for this pipeline step. Model artifacts are located by
# concatenating --model_path with one of the version suffixes below.
parser = argparse.ArgumentParser()

# Hyper-parameter flags (floats).
for _flag, _default in (
    ('--learning_rate', 0.001),
    ('--dropout_rate', 0.3),
):
    parser.add_argument(_flag, required=False, type=float, default=_default)

# Location flags (strings): base path plus the version suffixes used for the
# base model, the quantization-aware Keras model and the TFLite file.
for _flag, _default in (
    ('--model_path', '/result/saved_model'),
    ('--model_version', '1'),
    ('--model_version3', '1'),
    ('--model_version8', '1'),
):
    parser.add_argument(_flag, required=False, type=str, default=_default)
del _flag, _default

args = parser.parse_args()


################### tflite functions ###########################
def evaluate_model(interpreter, images=None, labels=None):
    """Return the top-1 accuracy of a TFLite interpreter on labelled images.

    Each image is fed to the interpreter individually (batch size 1) as
    float32, and the arg-max of the first output tensor is taken as the
    predicted class.

    Args:
        interpreter: a ``tf.lite.Interpreter`` on which ``allocate_tensors()``
            has already been called; only its first input and first output
            tensors are used.
        images: array-like of images to evaluate. Defaults to the
            module-level ``test_images`` (the original, implicit behavior).
        labels: ground-truth class indices aligned with ``images``. Defaults
            to the module-level ``test_labels``.

    Returns:
        Fraction (0.0-1.0) of images whose predicted class equals its label.

    Raises:
        ValueError: if ``images`` is empty or its length differs from
            ``labels``.
    """
    # Fall back to the module-level test set so existing callers that pass
    # only the interpreter keep working.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels
    labels = np.asarray(labels)

    if len(images) != len(labels):
        raise ValueError(
            'images and labels differ in length: {} vs {}'.format(len(images), len(labels)))
    if len(images) == 0:
        raise ValueError('cannot evaluate on an empty dataset')

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on every image in the dataset.
    prediction_digits = []
    for i, image in enumerate(images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))

        # Add the batch dimension; the interpreter input is fed as float32.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)

        # Run inference.
        interpreter.invoke()

        # get_tensor returns a copy, so no reference into the interpreter's
        # internal buffers outlives this iteration.
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))

    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    return float((prediction_digits == labels).mean())

def get_gzipped_model_size(file):
    """Return the size in bytes of ``file`` once compressed.

    The file is written into a temporary ZIP archive using DEFLATE
    compression (the same algorithm gzip uses). The archive is removed
    before returning.

    Args:
        file: path (str or path-like) of the file to measure.

    Returns:
        Size of the compressed archive, in bytes.
    """
    # mkstemp opens the file and hands back an OS-level descriptor; close it
    # right away (ZipFile reopens the path itself) so it does not leak.
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(file)
        return os.path.getsize(zipped_file)
    finally:
        # Do not leave one stray archive in the temp dir per call.
        os.remove(zipped_file)
###########################################################################

mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1].
train_images = train_images / 255.0
test_images = test_images / 255.0

# model load(base model)
# NOTE: artifact paths are built by plain string concatenation, e.g.
# '/result/saved_model' + '1' -> '/result/saved_model1'.
model = keras.models.load_model(args.model_path+args.model_version)

quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile. Honor --learning_rate (previously
# parsed but ignored); its default 0.001 equals Adam's default, so runs with
# default flags behave as before.
# NOTE(review): from_logits=True assumes the base model emits raw logits
# (no final softmax) - confirm against the model definition.
q_aware_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=args.learning_rate),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

q_aware_model.summary()

# NOTE(review): these subsets are not used - fit() below trains on the full
# training set. Confirm whether fine-tuning on the subset was intended.
train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images, train_labels,batch_size=500, epochs=1, validation_split=0.1)

results = q_aware_model.evaluate(test_images,test_labels, batch_size=500)
print('test loss, test acc:', results)
loss = results[0]
accuracy = results[1]

# Metrics file consumed by the pipeline UI.
metrics = {
    'metrics': [{
        'name': 'accuracy',
        'numberValue': float(accuracy),
        'format': "PERCENTAGE",
    }, {
        'name': 'loss',
        'numberValue': float(loss),
        'format': "RAW",
    }]
}

with file_io.FileIO('/mlpipeline-metrics.json', 'w') as f:
    json.dump(metrics, f)

# evaluate() above already produced the test accuracy on the same data;
# reuse it instead of running a second full evaluation pass.
q_aware_model_accuracy = accuracy
tf.keras.models.save_model(q_aware_model, args.model_path+args.model_version3, include_optimizer=False)

print("Quan TFLite")
# Convert the quantization-aware model to TFLite with the default
# post-training optimizations.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qaware_dynamic_post_quan_tflite_model = converter.convert()

quantizationAware_tflite_file =  pathlib.Path(args.model_path+args.model_version8)
quantizationAware_tflite_file.write_bytes(qaware_dynamic_post_quan_tflite_model)

# Evaluate the converted model. The previous code referenced an undefined
# name (`quantized_tflite_model`) here, which raised NameError.
interpreter_q = tf.lite.Interpreter(model_content=qaware_dynamic_post_quan_tflite_model)
interpreter_q.allocate_tensors()

test_accuracy_q = evaluate_model(interpreter_q)

print('Quantization Aware + TFLite test_accuracy:', test_accuracy_q)
print('Quantization Aware + TFLite model size:', get_gzipped_model_size(quantizationAware_tflite_file))