# In [None]:
import os
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_model_optimization as tfmot

import argparse
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.lib.io import file_io
import json
import zipfile

# Command-line options. Every option is optional; each table entry is
# (flag, type, default) and is registered in the order listed.
# NOTE: model_path is later concatenated directly with a model_version* value
# (no path separator), e.g. '/result/saved_model' + '1' -> '/result/saved_model1'.
# model_version names the base model that is loaded; model_version3 names the
# quantization-aware model that is saved.
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    ('--learning_rate', float, 0.001),
    ('--dropout_rate', float, 0.3),
    ('--model_path', str, '/result/saved_model'),
    ('--model_version', str, '1'),
    ('--model_version3', str, '1'),
    ('--model_version8', str, '1'),
)
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, required=False, type=_type, default=_default)
args = parser.parse_args()


# 모델 사이즈를 측정하기 위한 함수
def get_gzipped_model_size(file):
    _, zipped_file = tempfile.mkstemp('.zip')#,dir=args.model_path
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
        
    return os.path.getsize(zipped_file)


# Load the MNIST train/test splits, then convert the images to float32 and
# divide by 255 so pixel values fall in [0, 1].
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images, test_images = (
    split.astype(np.float32) / 255.0 for split in (train_images, test_images)
)

# Load the float base model. The path is model_path and model_version
# concatenated without a separator.
model = keras.models.load_model(args.model_path + args.model_version)

# Make the model quantization-aware with TF Model Optimization.
quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
# Use --learning_rate explicitly: the string 'adam' would always use Keras'
# default rate (0.001) and silently ignore the CLI value. The CLI default is
# also 0.001, so runs without the flag train exactly as before.
q_aware_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
    # NOTE(review): from_logits=True assumes the loaded model's last layer
    # emits raw logits (no softmax); confirm against the model producer.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

q_aware_model.summary()

# Fine-tune for one epoch on the full training set, holding out the last 10%
# for validation. For a quick smoke test, slice a subset instead, e.g.
# train_images[:1000] / train_labels[:1000].
q_aware_model.fit(train_images, train_labels, batch_size=500, epochs=1, validation_split=0.1)

# Score the quantization-aware model on the test split.
results = q_aware_model.evaluate(test_images, test_labels, batch_size=500)
print('test loss, test acc:', results)
loss, accuracy = results[0], results[1]

# Metrics payload; the file name and schema presumably follow the Kubeflow
# Pipelines (v1) metrics convention. Entry order: accuracy, then loss.
metrics = {
    'metrics': [
        {'name': metric_name, 'numberValue': float(metric_value), 'format': metric_format}
        for metric_name, metric_value, metric_format in (
            ('accuracy', accuracy, "PERCENTAGE"),
            ('loss', loss, "RAW"),
        )
    ]
}

with file_io.FileIO('/mlpipeline-metrics.json', 'w') as metrics_file:
    json.dump(metrics, metrics_file)

# The test-set accuracy of this exact model was already computed into
# `accuracy`. A second evaluate() over the same data would repeat a full
# inference pass just to reproduce that number, so reuse it.
q_aware_model_accuracy = accuracy
print("Quant test accuracy: ", q_aware_model_accuracy)

# Save without optimizer state. The target path is model_path and
# model_version3 concatenated without a separator. Without an .h5 suffix this
# is presumably written as a SavedModel directory; confirm for the TF version
# in use.
q_aware_model_path = args.model_path + args.model_version3
tf.keras.models.save_model(q_aware_model, q_aware_model_path, include_optimizer=False)
print("Quant model size: ", get_gzipped_model_size(q_aware_model_path))