In [1]:
! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tempfile
import os
import tensorflow_model_optimization as tfmot

# Toy problem size: the model maps `input_dim` features to `output_dim` classes.
input_dim = 20
output_dim = 20
# One random sample; the notebook only needs *some* data to fit on.
x_train = np.random.randn(1, input_dim).astype(np.float32)
# Draw an integer class id in [0, output_dim) and one-hot encode it.
# The previous `np.random.randn(1)` produced a float drawn from a normal
# distribution, which is not a valid class id for `to_categorical`.
y_train = tf.keras.utils.to_categorical(
    np.random.randint(output_dim, size=1), num_classes=output_dim
)

def setup_model():
    """Build the small, untrained Keras model used throughout this notebook.

    Returns:
        A `tf.keras.Sequential` model made of one `Dense` layer that takes
        `input_dim` features and emits `output_dim` units, followed by a
        `Flatten` layer.
    """
    model = tf.keras.Sequential([
        # The unit count has to match the label width (`output_dim`), not the
        # input width; the two were only interchangeable because both
        # constants currently hold the same value.
        tf.keras.layers.Dense(output_dim, input_shape=[input_dim]),
        tf.keras.layers.Flatten()
    ])
    return model

def train_model(model):
    """Compile `model`, print its summary, and fit it on the notebook's data.

    Args:
        model: Keras model to compile and train in place.

    Returns:
        The same `model` object after `fit` has run on `x_train` / `y_train`.
    """
    compile_options = dict(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.compile(**compile_options)

    model.summary()
    model.fit(x_train, y_train)
    return model

def save_model_weights(model):
    """Write the model's weights to a fresh temporary `.h5` file.

    Args:
        model: Object exposing `save_weights(path)`.

    Returns:
        Path of the temporary file holding the weights. The file is not
        deleted here; the caller owns it.
    """
    fd, pretrained_weights = tempfile.mkstemp('.h5')
    # `mkstemp` returns an already-open OS-level descriptor together with the
    # path. Only the path is used, so close the descriptor instead of leaking
    # one per call.
    os.close(fd)
    model.save_weights(pretrained_weights)
    return pretrained_weights

def setup_pretrained_weights():
    """Train a freshly built model and return the path of its saved weights."""
    trained = train_model(setup_model())
    return save_model_weights(trained)

def setup_pretrained_model():
    """Return a new model whose weights come from a separately trained copy.

    A fresh model is built first; a second one is then trained and saved via
    `setup_pretrained_weights`, and those weights are loaded into the first.
    """
    pretrained = setup_model()
    weights_path = setup_pretrained_weights()
    pretrained.load_weights(weights_path)
    return pretrained

def save_model_file(model):
    """Save the whole model, without optimizer state, to a temporary `.h5` file.

    Args:
        model: Object exposing `save(path, include_optimizer=...)`.

    Returns:
        Path of the temporary file. The file is not deleted here; the caller
        owns it.
    """
    fd, keras_file = tempfile.mkstemp('.h5')
    # `mkstemp` returns an already-open OS-level descriptor together with the
    # path. Only the path is used, so close the descriptor instead of leaking
    # one per call.
    os.close(fd)
    model.save(keras_file, include_optimizer=False)
    return keras_file

def get_gzipped_model_size(model):
    """Return the size in bytes of the saved model after DEFLATE compression.

    The model is written to a temporary file with `save_model_file`, that file
    is added to a temporary zip archive, and the archive's on-disk size is
    returned.

    Args:
        model: Model accepted by `save_model_file`.

    Returns:
        Size of the zip archive in bytes, as reported by `os.path.getsize`.
    """
    import os
    import zipfile

    keras_file = save_model_file(model)

    zip_fd, zipped_file = tempfile.mkstemp('.zip')
    # Only the path is needed; close the descriptor `mkstemp` opened so it is
    # not leaked on every call.
    os.close(zip_fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(keras_file)
        return os.path.getsize(zipped_file)
    finally:
        # Both files are scratch artifacts that never leave this function, so
        # remove them once the size is known. Cleanup is best-effort: a file
        # that cannot be removed must not mask the measured result.
        for scratch in (keras_file, zipped_file):
            try:
                os.remove(scratch)
            except OSError:
                pass

# Build one model whose return value is not kept, then train a second model
# and keep the path of its saved weights; later cells load weights from
# `pretrained_weights`.
setup_model()
pretrained_weights = setup_pretrained_weights()

dyld: Library not loaded: /System/Library/Frameworks/CoreFoundation.framework/Versions/A/CoreFoundation
  Referenced from: /Library/Frameworks/Python.framework/Versions/3.6/Resources/Python.app/Contents/MacOS/Python
  Reason: image not found
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_1 (Dense)             (None, 20)                420       
                                                                 
 flatten_1 (Flatten)         (None, 20)                0         
                                                                 
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


In [2]:
# Cluster a whole model (sequential and functional)
import tensorflow_model_optimization as tfmot

# Short aliases for the clustering entry point and the centroid-init enum.
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# Keyword arguments reused by the `cluster_weights(...)` calls in this notebook.
clustering_params = {
    'number_of_clusters': 3,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

# Start from a fresh model and restore the weights saved by the first cell.
model = setup_model()
model.load_weights(pretrained_weights)

# Passing the whole model wraps every layer (see the summary printed below).
clustered_model = cluster_weights(model, **clustering_params)

clustered_model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_dense_2 (ClusterWei  (None, 20)               823       
 ghts)                                                           
                                                                 
 cluster_flatten_2 (ClusterW  (None, 20)               0         
 eights)                                                         
                                                                 
Total params: 823
Trainable params: 423
Non-trainable params: 400
_________________________________________________________________


In [3]:
# Cluster some layers (sequential and functional models)
# Create a base model and restore the pretrained weights into it.
base_model = setup_model()
base_model.load_weights(pretrained_weights)

# Helper function uses `cluster_weights` to make only 
# the Dense layers train with clustering
def apply_clustering_to_dense(layer):
    """Wrap `Dense` layers with `cluster_weights`; return other layers as-is.

    Args:
        layer: A Keras layer.

    Returns:
        `cluster_weights(layer, **clustering_params)` when `layer` is a
        `tf.keras.layers.Dense` instance, otherwise `layer` unchanged.
    """
    if not isinstance(layer, tf.keras.layers.Dense):
        return layer
    return cluster_weights(layer, **clustering_params)

# Use `tf.keras.models.clone_model` with `apply_clustering_to_dense` as the
# clone function, so each layer of `base_model` is passed through it.
clustered_model = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_clustering_to_dense,
)

# Only the Dense layer should show up as a `ClusterWeights` wrapper here.
clustered_model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_dense_3 (ClusterWei  (None, 20)               823       
 ghts)                                                           
                                                                 
 flatten_3 (Flatten)         (None, 20)                0         
                                                                 
Total params: 823
Trainable params: 423
Non-trainable params: 400
_________________________________________________________________


In [4]:
# Cluster custom Keras layer or specify which weights of layer to cluster
class MyDenseLayer(tf.keras.layers.Dense, tfmot.clustering.keras.ClusterableLayer):
    """`Dense` layer that declares which of its weights should be clustered."""

    def get_clusterable_weights(self):
        """Return `(name, variable)` pairs for the weights to cluster.

        Both the kernel and the bias are listed. This is just an example;
        clustering the bias usually hurts model accuracy.
        """
        clusterable = [('kernel', self.kernel)]
        clusterable.append(('bias', self.bias))
        return clusterable

# Use `cluster_weights` to make the `MyDenseLayer` layer train with clustering as usual.
# NOTE(review): the unit count is the literal 20 rather than one of the
# `input_dim` / `output_dim` constants; confirm which one is intended.
model_for_clustering = tf.keras.Sequential([
  tfmot.clustering.keras.cluster_weights(MyDenseLayer(20, input_shape=[input_dim]), **clustering_params),
  tf.keras.layers.Flatten()
])

model_for_clustering.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_my_dense_layer (Clu  (None, 20)               846       
 sterWeights)                                                    
                                                                 
 flatten_4 (Flatten)         (None, 20)                0         
                                                                 
Total params: 846
Trainable params: 426
Non-trainable params: 420
_________________________________________________________________


In [5]:
# Checkpoint and deserialize a clustered model
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights)
clustered_model = cluster_weights(base_model, **clustering_params)

# Save or checkpoint the model.
# `mkstemp` hands back an open OS-level descriptor along with the path; only
# the path is needed, so close the descriptor instead of leaking it.
checkpoint_fd, keras_model_file = tempfile.mkstemp('.h5')
os.close(checkpoint_fd)
clustered_model.save(keras_model_file, include_optimizer=True)

# `cluster_scope` is needed for deserializing HDF5 models.
with tfmot.clustering.keras.cluster_scope():
    loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()

Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_dense_4 (ClusterWei  (None, 20)               823       
 ghts)                                                           
                                                                 
 cluster_flatten_5 (ClusterW  (None, 20)               0         
 eights)                                                         
                                                                 
Total params: 823
Trainable params: 423
Non-trainable params: 400
_________________________________________________________________


In [6]:
# Deployment
# Export model with size compression
# NOTE(review): unlike the earlier cells, this model is clustered without
# first loading `pretrained_weights` — confirm that is intentional.
model = setup_model()
clustered_model = cluster_weights(model, **clustering_params)

clustered_model.compile(
                        loss=tf.keras.losses.categorical_crossentropy, 
                        optimizer='adam',
                        metrics=['accuracy']
)

clustered_model.fit(x_train, y_train)

# Strip the clustering wrappers; the summary printed below lists plain
# `Dense` / `Flatten` layers again.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

print("final model")
final_model.summary()

# Compare the compressed on-disk size before and after stripping.
# The sizes come from `os.path.getsize`, so `%.2f` always prints `.00`.
print("\n")
print("Size of gzipped clustered model without stripping: %.2f bytes" 
      % (get_gzipped_model_size(clustered_model)))
print("Size of gzipped clustered model with stripping: %.2f bytes" 
      % (get_gzipped_model_size(final_model)))

final model
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_5 (Dense)             (None, 20)                420       
                                                                 
 flatten_6 (Flatten)         (None, 20)                0         
                                                                 
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


Size of gzipped clustered model without stripping: 3551.00 bytes
Size of gzipped clustered model with stripping: 1550.00 bytes
