In [1]:
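# Prerequisite: the clustering cells below need the TensorFlow Model
# Optimization Toolkit. If it is not installed yet, run this in a cell:
#   %pip install -q tensorflow-model-optimization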

In [2]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
import tempfile
import zipfile
import os

Train a tf.keras model for MNIST without clustering

In [3]:
# Fetch MNIST as (train, test) pairs of image and label arrays.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale pixel intensities so that every value lies between 0 and 1.
train_images, test_images = train_images / 255.0, test_images / 255.0

# Baseline architecture: a single convolution + max-pooling stage followed by
# a dense head that emits 10 raw logits (no activation on the last layer).
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28, 28)))
# Conv2D needs an explicit channel axis, so add a trailing dimension of 1.
model.add(keras.layers.Reshape(target_shape=(28, 28, 1)))
model.add(
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10))

# from_logits=True matches the activation-free Dense(10) output above.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

# Train the digit classifier; validation_split=0.1 reserves a tenth of the
# training data for validation.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=10,
    validation_split=0.1,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7fb4a218ba00>

Evaluate the baseline model and save it for later usage

In [4]:
# Score the trained baseline on the held-out test set. The result is unpacked
# as (loss, accuracy); only the accuracy is kept.
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

# tempfile.mkstemp() returns an already-open OS-level file descriptor together
# with the path. Only the path is needed here (save_model opens the file by
# name), so close the descriptor immediately instead of leaking it for the
# lifetime of the kernel.
keras_fd, keras_file = tempfile.mkstemp('.h5')
os.close(keras_fd)

print('Saving model to: ', keras_file)
# include_optimizer=False: store the model without its optimizer state.
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.98089998960495
Saving model to:  /var/folders/gb/4t3qb3rn74q3j_03hdn28vwr0000gn/T/tmpnqrva_g4.h5


Define the model and apply the clustering API

In [5]:
import tensorflow_model_optimization as tfmot

# Local aliases for the two clustering names used in this cell.
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
cluster_weights = tfmot.clustering.keras.cluster_weights

# Clustering hyper-parameters: 16 clusters, LINEAR centroid initialization.
clustering_params = dict(
    number_of_clusters=16,
    cluster_centroids_init=CentroidInitialization.LINEAR,
)

# Apply clustering to the whole trained baseline model at once.
clustered_model = cluster_weights(model, **clustering_params)

# Fine-tuning uses a small learning rate (1e-5) for the clustered model.
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

# Same loss and metric as the baseline so the accuracies are comparable.
clustered_model.compile(
    optimizer=opt,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

clustered_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_reshape (ClusterWei  (None, 28, 28, 1)        0         
 ghts)                                                           
                                                                 
 cluster_conv2d (ClusterWeig  (None, 26, 26, 12)       244       
 hts)                                                            
                                                                 
 cluster_max_pooling2d (Clus  (None, 13, 13, 12)       0         
 terWeights)                                                     
                                                                 
 cluster_flatten (ClusterWei  (None, 2028)             0         
 ghts)                                                           
                                                                 
 cluster_dense (ClusterWeigh  (None, 10)               4

In [6]:
# Fine-tune the clustered model for a single epoch in batches of 500;
# validation_split=0.1 reserves a tenth of the training data for validation.
clustered_model.fit(
    x=train_images,
    y=train_labels,
    epochs=1,
    batch_size=500,
    validation_split=0.1,
)



<keras.callbacks.History at 0x7fb46e00d040>