In [None]:
import logging
import os
import tempfile
import zipfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras

# Lower the "tensorflow" logger threshold to DEBUG so that its debug-level
# records are emitted while the cells below run.
logging.getLogger("tensorflow").setLevel(logging.DEBUG)

# Fetch the MNIST digits through Keras: a (images, labels) pair for the
# training split and another for the test split.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale both image arrays by 1/255 so pixel intensities land in [0, 1].
train_images, test_images = train_images / 255.0, test_images / 255.0

In [None]:
# Baseline classifier: one convolution + max-pool stage feeding a 10-way
# dense layer that outputs raw logits (no softmax).
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28, 28)))
# Add an explicit channel axis so Conv2D receives (28, 28, 1) images.
model.add(keras.layers.Reshape(target_shape=(28, 28, 1)))
model.add(keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10))

# from_logits=True matches the activation-free Dense output layer above.
model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train for a single epoch, holding out 10% of the training data for validation.
model.fit(train_images, train_labels, validation_split=0.1, epochs=1)

model.summary()

In [None]:
# Magnitude-prune the trained model. The sparsity target follows a polynomial
# schedule from 50% to 80% that ends after `epochs` epoch(s) of fine-tuning.
batch_size = 128
epochs = 1
validation_split = 0.1  # 10% of the training set is held out for validation.

# Number of optimizer steps taken by the fine-tuning run below: batches per
# epoch on the non-validation part of the training set, times `epochs`.
num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step,
    )
}
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

# Wrapping the model is not enough: the pruning masks are only updated while
# training with the UpdatePruningStep callback. Without this fine-tuning run
# the schedule above never executes and strip_pruning() would hand back the
# original dense weights.
pruned_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]
pruned_model.fit(
    train_images, train_labels,
    batch_size=batch_size, epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks
)

# Drop the pruning wrappers so the next stage receives a plain Keras model.
model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model)
model_for_export.summary()

In [None]:
# Weight clustering: every clusterable layer is restricted to 16 shared
# weight values, with centroids initialised linearly.
clustering_params = {
    'number_of_clusters': 16,
    'cluster_centroids_init': tfmot.clustering.keras.CentroidInitialization.LINEAR
}

# Cluster a whole model
# NOTE(review): plain cluster_weights() is presumably not sparsity-aware, so
# zeros introduced by an earlier pruning stage may be lost here - confirm
# whether tfmot's sparsity-preserving clustering variant is wanted instead.
clustered_model = tfmot.clustering.keras.cluster_weights(model_for_export, **clustering_params)

# Fine-tune with a small learning rate so the centroids can adapt to the data;
# skipping this step exports weights snapped to the initial centroids without
# any recovery training.
clustered_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    metrics=['accuracy']
)
clustered_model.fit(
    train_images,
    train_labels,
    batch_size=500,
    epochs=1,
    validation_split=0.1
)

clustered_model.summary()

# Remove the clustering wrappers, leaving a plain Keras model.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
final_model.summary()

In [None]:
# Wrap the clustered model for quantization-aware training (QAT).
q_aware_model = tfmot.quantization.keras.quantize_model(final_model)

# quantize_model() returns an uncompiled model whose quantizer ranges still
# have to be learned; train it briefly so the ranges are calibrated on real
# data before the model is handed to the TFLite converter.
q_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
q_aware_model.fit(
    train_images,
    train_labels,
    batch_size=500,
    epochs=1,
    validation_split=0.1
)
q_aware_model.summary()

In [None]:
# Create a quantized model for the TFLite backend, yielding an actually
# quantized model with int8 weights and uint8 activations.
# The converter is built from the quantization-aware Keras model and run with
# the DEFAULT optimization set; convert() returns the serialized model, which
# the next cell passes to the interpreter as `model_content`.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()


In [None]:
# Imported at the top of the cell so a missing matplotlib fails before any
# inference work is done; pyplot replaces the legacy pylab module.
import matplotlib.pyplot as plt

# Run the TensorFlow Lite model with the Python TensorFlow Lite interpreter.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

# Test a single example: add a leading batch axis and cast to float32.
num = 3
test_image = np.expand_dims(test_images[num], axis=0).astype(np.float32)

# Tensor indices of the interpreter's first input and first output.
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

interpreter.set_tensor(input_index, test_image)
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)

# Show the digit together with its true label and the arg-max prediction.
plt.imshow(test_images[num])
template = "True:{true}, predicted:{predict}"
_ = plt.title(template.format(true= str(test_labels[num]),
                              predict=str(np.argmax(predictions[0]))))
plt.grid(False)