# In [None]:
import os
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_model_optimization as tfmot

import argparse
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.lib.io import file_io
import json
import zipfile
import pathlib


################### args ##########################################
# Command-line options for this post-training-quantization step.
# model_path is later combined with model_version / model_version2 by plain
# string concatenation, so any path separator must be part of the values.
# (learning_rate and dropout_rate are accepted but not read in the code visible here.)
parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=0.001, required=False)
parser.add_argument("--dropout_rate", type=float, default=0.3, required=False)
parser.add_argument("--model_path", type=str, default="/result/saved_model", required=False)
parser.add_argument("--model_version", type=str, default="1", required=False)
parser.add_argument("--model_version2", type=str, default="1", required=False)
args = parser.parse_args()


################### tflite functions ###########################
def devaluate_model(interpreter, images=None, labels=None):
    """Return the top-1 accuracy of a TFLite interpreter on a labelled image set.

    Each image is fed to the interpreter's first input tensor one at a time:
    a batch dimension is added and the data is cast to float32. The predicted
    class is the argmax of the first output tensor.

    Args:
        interpreter: A ``tf.lite.Interpreter`` on which ``allocate_tensors()``
            has already been called.
        images: Images to evaluate. Defaults to the module-level
            ``test_images`` (the original behaviour).
        labels: Ground-truth class indices aligned with ``images``. Defaults
            to the module-level ``test_labels``.

    Returns:
        Fraction (float in [0, 1]) of images whose predicted class equals
        the corresponding label.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length, or if
            there are no images to evaluate.
    """
    # Backward compatibility: the original version always read these globals.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    if len(images) != len(labels):
        raise ValueError(
            "images and labels must have the same length, got {} and {}".format(
                len(images), len(labels)
            )
        )
    if len(images) == 0:
        raise ValueError("cannot compute accuracy on an empty dataset")

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    correct = 0
    for image, label in zip(images, labels):
        # Add the batch dimension and convert to float32 to match the
        # model's input data format.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)
        interpreter.invoke()

        # get_tensor returns a copy, so no reference into the interpreter's
        # internal buffer is held across the next invoke().
        predicted = np.argmax(interpreter.get_tensor(output_index)[0])
        if predicted == label:
            correct += 1

    return correct / len(images)

def get_gzipped_model_size(file):
    """Return the size, in bytes, of *file* after DEFLATE-compressing it into a zip archive.

    Despite the name, the container is a ``.zip`` archive (DEFLATE), not
    gzip. The archive is written to a temporary file, which is removed
    before returning.

    Args:
        file: Path (``str`` or ``os.PathLike``) of the file to compress.

    Returns:
        Size of the compressed archive in bytes.
    """
    # mkstemp returns an already-open descriptor. Close it immediately so it
    # does not leak and so ZipFile can reopen the path (needed on Windows).
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
            archive.write(file)
        return os.path.getsize(zipped_file)
    finally:
        # Do not leave one temporary archive behind per call.
        os.remove(zipped_file)
###########################################################################

# Load the base (float) Keras model. Paths are formed by plain string
# concatenation, e.g. '/result/saved_model' + '1' -> '/result/saved_model1'.
base_model_path = args.model_path + args.model_version
dynamic_post_quan_tflite_file = pathlib.Path(args.model_path + args.model_version2)

# Fail fast, before the expensive load/convert, if the .tflite output would
# land on the base model. With the default arguments (model_version and
# model_version2 are both '1') the two paths are identical. write_bytes would
# then fail if the base model is a directory, or overwrite it if it is a file.
if dynamic_post_quan_tflite_file == pathlib.Path(base_model_path):
    raise ValueError(
        'TFLite output path {} is the same as the base model path; '
        'pass a different --model_version2.'.format(dynamic_post_quan_tflite_file)
    )

model = keras.models.load_model(base_model_path)

# Dynamic-range post-training quantization: Optimize.DEFAULT with no
# representative dataset quantizes the weights only.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
dynamic_post_quan_tflite_model = converter.convert()

interpreter_dynamic_post_q = tf.lite.Interpreter(model_content=dynamic_post_quan_tflite_model)
# Report the input/output dtypes of the converted model.
input_type = interpreter_dynamic_post_q.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter_dynamic_post_q.get_output_details()[0]['dtype']
print('output: ', output_type)

# Persist the quantized flatbuffer, creating the parent directory if needed.
dynamic_post_quan_tflite_file.parent.mkdir(parents=True, exist_ok=True)
dynamic_post_quan_tflite_file.write_bytes(dynamic_post_quan_tflite_model)

interpreter_dynamic_post_q.allocate_tensors()
# NOTE(review): devaluate_model reads module-level test_images / test_labels,
# which are not defined in the code visible here. Make sure they exist
# (e.g. from an earlier cell) before this point, or this raises NameError.
dynamic_post_q_accuracy = devaluate_model(interpreter_dynamic_post_q)

print('Dynamic Post-Quant test accuracy:', dynamic_post_q_accuracy)
print('Dynamic Post-Quant size:', get_gzipped_model_size(dynamic_post_quan_tflite_file))