# In [None]:
import os
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_model_optimization as tfmot

import argparse
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.lib.io import file_io
import json


# Helper used to measure the compressed size of a saved model.
def get_gzipped_model_size(file):
    """Return the size, in bytes, of ``file`` after zip (DEFLATE) compression.

    ``file`` may be a single file or a directory. A directory is archived
    recursively so that its contents count towards the result;
    ``ZipFile.write`` on a directory path would otherwise store only the
    (empty) directory entry.

    Args:
        file: Path of the file or directory to measure.

    Returns:
        Size in bytes of a temporary zip archive built from ``file``. The
        archive is deleted before returning.
    """
    import os
    import zipfile

    # mkstemp hands back an already-open OS-level descriptor. Close it right
    # away so it is not leaked; ZipFile reopens the path on its own.
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            if os.path.isdir(file):
                for root, dirs, names in os.walk(file):
                    dirs.sort()  # deterministic traversal order
                    for name in sorted(names):
                        path = os.path.join(root, name)
                        f.write(path, arcname=os.path.relpath(path, file))
            else:
                f.write(file)
        return os.path.getsize(zipped_file)
    finally:
        # The archive only exists to be measured; do not leave it behind.
        os.remove(zipped_file)

# Command-line options, declared as (flag, type, default) rows.
# --model_path is a path prefix: the script loads the base model from
# model_path + model_version and writes its output to
# model_path + model_version2 (plain string concatenation).
# --learning_rate and --dropout_rate are parsed but not read anywhere in the
# visible part of this script.
_OPTION_TABLE = (
    ('--learning_rate', float, 0.001),
    ('--dropout_rate', float, 0.3),
    ('--model_path', str, '/result/saved_model'),
    ('--model_version', str, '1'),
    ('--model_version2', str, '1'),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _fallback in _OPTION_TABLE:
    parser.add_argument(_flag, required=False, type=_kind, default=_fallback)
args = parser.parse_args()


# Short alias for the tfmot wrapper applied to the loaded model further down.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Fine-tuning hyper-parameters.
batch_size, epochs, validation_split = 128, 5, 0.1

# MNIST train / test splits.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Convert the images to float32 and divide by 255
# (presumably maps 8-bit pixel values into [0, 1] -- confirm).
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# Number of optimizer steps the fit call will run: only the non-validation
# share of the training set is trained on, in batches, for every epoch.
num_images = (1 - validation_split) * train_images.shape[0]
steps_per_epoch = np.ceil(num_images / batch_size).astype(np.int32)
end_step = steps_per_epoch * epochs

# Sparsity schedule going from 50% to 90% between step 0 and the last
# training step.
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.50,
    final_sparsity=0.90,
    begin_step=0,
    end_step=end_step,
)
pruning_params = {'pruning_schedule': pruning_schedule}

# Load the base (unpruned) model from <model_path><model_version>.
model = keras.models.load_model(args.model_path + args.model_version)

# Wrap the model for magnitude pruning using the schedule in pruning_params.
model_for_pruning = prune_low_magnitude(model, **pruning_params)

# The wrapped model is compiled before training.
# NOTE(review): from_logits=True assumes the loaded model outputs raw logits
# (no final softmax) -- confirm against the base model definition.
model_for_pruning.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

model_for_pruning.summary()

# UpdatePruningStep advances the pruning step counter during training;
# tfmot requires it when fitting a pruned model.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
]

model_for_pruning.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
)

# Evaluate once on the test split; results is [loss, accuracy] because the
# model was compiled with a single 'accuracy' metric.
results = model_for_pruning.evaluate(test_images, test_labels, batch_size=128)
print('test loss, test acc:', results)
loss = results[0]
accuracy = results[1]

# Metrics payload written for the pipeline to pick up.
metrics = {
    'metrics': [{
        'name': 'accuracy',
        'numberValue': float(accuracy),
        'format': "PERCENTAGE",
    }, {
        'name': 'loss',
        'numberValue': float(loss),
        'format': "RAW",
    }]
}

with file_io.FileIO('/mlpipeline-metrics.json', 'w') as f:
    json.dump(metrics, f)

# Reuse the accuracy computed above instead of running a second, identical
# evaluation pass over the test set.
model_for_pruning_accuracy = accuracy
print("Pruned model accuracy : ", model_for_pruning_accuracy)

# Remove the pruning wrappers, then save the *stripped* model. Saving `model`
# here would discard the strip_pruning result and export the wrong object.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# NOTE(review): --model_version and --model_version2 both default to '1', so
# with default arguments this overwrites the base model on disk -- confirm
# that is intended.
pruned_model_path = args.model_path + args.model_version2
tf.keras.models.save_model(model_for_export, pruned_model_path, include_optimizer=False)

print("Pruned model size: ", get_gzipped_model_size(pruned_model_path))
