In [1]:
# pip install -q tensorflow-model-optimization

In [2]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
import tempfile
import zipfile
import os

Train a tf.keras model for MNIST without clustering

In [3]:
# Fetch the MNIST digits that ship with Keras, already split into train/test.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale both image sets by 1/255 so every pixel value lies between 0 and 1.
train_images, test_images = train_images / 255.0, test_images / 255.0

# Assemble the classifier layer by layer: add a channel axis, apply one
# 3x3 convolution with 12 filters, pool, flatten, and finish with a dense
# layer that emits 10 raw logits (no softmax is applied here).
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28, 28)))
model.add(keras.layers.Reshape(target_shape=(28, 28, 1)))
model.add(keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10))

# from_logits=True matches the un-activated output of the final Dense layer.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train the digit classifier, holding out 10% of the training data for validation.
model.fit(train_images, train_labels, epochs=10, validation_split=0.1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f842f020a00>

Evaluate the baseline model and save it for later usage

In [4]:
# Score the trained baseline on the test set. The first returned value (the
# loss) is not needed, only the accuracy metric.
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

# mkstemp() returns an already-open OS-level file descriptor together with
# the path. Only the path is needed (save_model opens the file itself), so
# close the descriptor immediately instead of leaking it for the lifetime of
# the kernel.
keras_fd, keras_file = tempfile.mkstemp('.h5')
os.close(keras_fd)
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.9821000099182129
Saving model to:  /var/folders/gb/4t3qb3rn74q3j_03hdn28vwr0000gn/T/tmpulfvyker.h5


Define the model and apply the clustering API

In [5]:
import tensorflow_model_optimization as tfmot

# Short aliases for the clustering entry point and the centroid-init enum.
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
cluster_weights = tfmot.clustering.keras.cluster_weights

# Clustering configuration: 16 clusters, LINEAR centroid initialization.
clustering_params = dict(
    number_of_clusters=16,
    cluster_centroids_init=CentroidInitialization.LINEAR,
)

# Apply clustering to the whole trained baseline model.
clustered_model = cluster_weights(model, **clustering_params)

# Fine-tuning the clustered model uses a small learning rate (1e-5).
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
    optimizer=opt,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

clustered_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_reshape (ClusterWei  (None, 28, 28, 1)        0         
 ghts)                                                           
                                                                 
 cluster_conv2d (ClusterWeig  (None, 26, 26, 12)       244       
 hts)                                                            
                                                                 
 cluster_max_pooling2d (Clus  (None, 13, 13, 12)       0         
 terWeights)                                                     
                                                                 
 cluster_flatten (ClusterWei  (None, 2028)             0         
 ghts)                                                           
                                                                 
 cluster_dense (ClusterWeigh  (None, 10)               4

In [6]:
# Fine-tune the clustered model for a single epoch with large batches,
# holding out 10% of the training data for validation.
clustered_model.fit(
    x=train_images,
    y=train_labels,
    epochs=1,
    batch_size=500,
    validation_split=0.1,
)



<keras.callbacks.History at 0x7f840fc304f0>

In [7]:
# Measure test-set accuracy of the fine-tuned clustered model and print it
# next to the baseline accuracy for a side-by-side comparison.
_, clustered_model_accuracy = clustered_model.evaluate(
    x=test_images,
    y=test_labels,
    verbose=0,
)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)

Baseline test accuracy: 0.9821000099182129
Clustered test accuracy: 0.9779000282287598


**Create 6x smaller models from clustering**

In [8]:
# Create a compressible model for TensorFlow: strip_clustering() returns the
# model that is saved and converted from here on.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

# mkstemp() hands back an open file descriptor as well as the path. Only the
# path is used (save_model opens the file itself), so close the descriptor
# right away rather than leaking it.
clustered_keras_fd, clustered_keras_file = tempfile.mkstemp('.h5')
os.close(clustered_keras_fd)
print('Saving clustered model to: ', clustered_keras_file)
tf.keras.models.save_model(final_model, clustered_keras_file,
                           include_optimizer=False)

Saving clustered model to:  /var/folders/gb/4t3qb3rn74q3j_03hdn28vwr0000gn/T/tmppqptn2gf.h5


In [9]:
# Create a compressible model for TFLite.
# The destination is built from the platform's temp directory instead of a
# literal '/tmp', which does not exist on every OS and is inconsistent with
# the tempfile-based paths used by the other cells.
clustered_tflite_file = os.path.join(tempfile.gettempdir(), 'clustered_mnist.tflite')
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_clustered_model = converter.convert()
with open(clustered_tflite_file, 'wb') as f:
    f.write(tflite_clustered_model)
print('Saved clustered TFLite model to:', clustered_tflite_file)

INFO:tensorflow:Assets written to: /var/folders/gb/4t3qb3rn74q3j_03hdn28vwr0000gn/T/tmp840768to/assets




Saved clustered TFLite model to: /tmp/clustered_mnist.tflite


In [10]:
# Define a helper function that compresses a model file and measures the
# size of the compressed result.
def get_gzipped_model_size(file):
    """Return the size in bytes of `file` after DEFLATE compression.

    Despite the name, the file is packed into a ZIP archive (DEFLATE
    compression) rather than a gzip stream, so the returned figure includes
    the ZIP container overhead. The archive is a scratch file: it is created
    in the temp directory, measured, and removed before returning.

    Args:
        file: Path of the file to compress.

    Returns:
        Size of the temporary ZIP archive, in bytes.

    Raises:
        OSError: If `file` cannot be read or the archive cannot be written.
    """
    # mkstemp() returns an open descriptor plus the path. ZipFile reopens the
    # path itself, so close the descriptor immediately to avoid leaking it.
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(file)
        return os.path.getsize(zipped_file)
    finally:
        # Only the size is of interest; do not leave one archive per call behind.
        os.remove(zipped_file)

In [11]:
# Report the compressed footprint of each saved artifact: the baseline Keras
# model, the clustered Keras model and the clustered TFLite model.
print("Size of gzipped baseline Keras model: %.2f bytes"
      % get_gzipped_model_size(keras_file))
print("Size of gzipped clustered Keras model: %.2f bytes"
      % get_gzipped_model_size(clustered_keras_file))
print("Size of gzipped clustered TFlite model: %.2f bytes"
      % get_gzipped_model_size(clustered_tflite_file))

Size of gzipped baseline Keras model: 78206.00 bytes
Size of gzipped clustered Keras model: 12911.00 bytes
Size of gzipped clustered TFlite model: 12488.00 bytes


Create an 8x smaller TFLite model from combining weight clustering and post-training quantization

In [12]:
# Convert the clustered model again, this time with the default TFLite
# optimizations enabled (post-training quantization).
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

# mkstemp() returns an open file descriptor together with the path. Write
# through that descriptor so it is closed by the `with` block; the original
# discarded it and reopened the path, leaking the descriptor.
quantized_fd, quantized_and_clustered_tflite_file = tempfile.mkstemp('.tflite')

with os.fdopen(quantized_fd, 'wb') as f:
    f.write(tflite_quant_model)

print('Saved quantized and clustered TFLite model to:', quantized_and_clustered_tflite_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped clustered and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_clustered_tflite_file)))

INFO:tensorflow:Assets written to: /var/folders/gb/4t3qb3rn74q3j_03hdn28vwr0000gn/T/tmpcq604uh5/assets


INFO:tensorflow:Assets written to: /var/folders/gb/4t3qb3rn74q3j_03hdn28vwr0000gn/T/tmpcq604uh5/assets


Saved quantized and clustered TFLite model to: /var/folders/gb/4t3qb3rn74q3j_03hdn28vwr0000gn/T/tmp3tu_8482.tflite
Size of gzipped baseline Keras model: 78206.00 bytes
Size of gzipped clustered and quantized TFlite model: 9655.00 bytes


See the persistence of accuracy from TF to TFLite

In [13]:
# Define a helper function to evaluate the TFLite model on the test dataset
def eval_model(interpreter, images=None, labels=None):
    """Run a TFLite interpreter over a labelled image set and return accuracy.

    Args:
        interpreter: Object providing get_input_details(),
            get_output_details(), set_tensor(), invoke() and tensor().
            Presumably a tf.lite.Interpreter whose tensors are already
            allocated -- confirm at the call site.
        images: Iterable of per-sample image arrays. Defaults to the
            notebook-level `test_images`.
        labels: Ground-truth digits, one per image. Defaults to the
            notebook-level `test_labels`.

    Returns:
        Fraction of samples whose predicted digit equals its label.

    Raises:
        ValueError: If the number of predictions and labels differ.
    """
    # Fall back to the notebook's test split so existing one-argument calls
    # behave exactly as before.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels
    labels = np.asarray(labels)

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    # tensor() returns a callable that yields the current output buffer; the
    # lookup does not depend on the sample, so do it once outside the loop.
    output = interpreter.tensor(output_index)

    # Run predictions on every image in the dataset.
    prediction_digits = []
    for i, test_image in enumerate(images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: remove batch dimension and take the index of the
        # largest output value as the predicted digit.
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)

    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    if prediction_digits.shape != labels.shape:
        # Fail loudly instead of relying on NumPy's behaviour when comparing
        # arrays of different shapes.
        raise ValueError(
            'Number of predictions {p} does not match labels {l}.'.format(
                p=prediction_digits.shape, l=labels.shape))
    accuracy = (prediction_digits == labels).mean()
    return accuracy