# Weight clustering in Keras example

## Overview

Welcome to the end-to-end example for *weight clustering*, part of the TensorFlow Model Optimization Toolkit.

### Other pages

For an introduction to what weight clustering is and to determine if you should use it (including what's supported), see the [overview page](https://www.tensorflow.org/model_optimization/guide/clustering).

To quickly find the APIs you need for your use case (beyond fully clustering a model with 16 clusters), see the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/clustering/clustering_comprehensive_guide).

### Contents

In the tutorial, you will:

1. Train a `tf.keras` model for the MNIST dataset from scratch.
2. Fine-tune the model by applying the weight clustering API and see the accuracy.
3. Create 6x smaller TF and TFLite models from clustering.
4. Create an 8x smaller TFLite model by combining weight clustering and post-training quantization.
5. See the persistence of accuracy from TF to TFLite.
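
Steps 3 and 4 depend on one property: a clustered tensor holds only a few distinct values, so a standard compressor can shrink it much further than an unclustered one. The following is a minimal sketch of that effect in pure Python, not the toolkit's implementation. It assumes evenly spaced centroids and a single nearest-centroid pass (the real API also fine-tunes the centroids during training), and the helper names `snap_to_centroids` and `compressed_size` are hypothetical.

```python
import random
import struct
import zlib


def snap_to_centroids(weights, n_clusters):
    """Replace each weight with the nearest of n_clusters evenly spaced centroids."""
    lo, hi = min(weights), max(weights)
    step = (hi - lo) / (n_clusters - 1)
    centroids = [lo + i * step for i in range(n_clusters)]
    return [min(centroids, key=lambda c: abs(c - w)) for w in weights]


def compressed_size(values):
    """Size in bytes of the values stored as float32 and compressed with zlib."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw, 9))


random.seed(0)
# Stand-in for a layer's weights; the distribution is an assumption for illustration.
weights = [random.gauss(0.0, 0.05) for _ in range(20000)]
clustered = snap_to_centroids(weights, n_clusters=16)

original_bytes = compressed_size(weights)
clustered_bytes = compressed_size(clustered)

print("distinct values:", len(set(weights)), "->", len(set(clustered)))
print("compressed bytes:", original_bytes, "->", clustered_bytes)
```

The uncompressed size is identical in both cases (four bytes per weight); the saving only appears after compression, which is why the size comparisons in this kind of workflow are usually made on compressed files.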

## Setup

You can run this Jupyter Notebook in a local [virtualenv](https://www.tensorflow.org/install/pip?lang=python3#2.-create-a-virtual-environment-recommended) or in [Colab](https://colab.sandbox.google.com/). For details on setting up dependencies, refer to the [installation guide](https://www.tensorflow.org/model_optimization/guide/install).

In [2]:
! pip install -q tensorflow-model-optimization

In [3]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
import tempfile
import zipfile
import os

## Train a tf.keras model for MNIST without clustering
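
Because clustering acts on weights, it helps to know how many weights the baseline model has and where they sit. The sketch below works out the layer shapes and parameter counts by hand for the architecture defined in the next cell, assuming the Keras defaults (stride 1 and `'valid'` padding for `Conv2D`, pool stride equal to pool size, biases enabled). The helper functions are hypothetical and exist only for this arithmetic.

```python
def conv2d_valid(height, width, in_channels, filters, kernel):
    """Output shape and parameter count of a stride-1, 'valid' Conv2D with bias."""
    out_shape = (height - kernel + 1, width - kernel + 1, filters)
    params = kernel * kernel * in_channels * filters + filters
    return out_shape, params


def max_pool(shape, pool):
    """Output shape of a MaxPooling2D whose stride equals its pool size."""
    height, width, channels = shape
    return (height // pool, width // pool, channels)


def dense_params(in_features, units):
    """Parameter count of a Dense layer with bias."""
    return in_features * units + units


conv_shape, conv_params = conv2d_valid(28, 28, 1, filters=12, kernel=3)
pool_shape = max_pool(conv_shape, pool=2)
flat_features = pool_shape[0] * pool_shape[1] * pool_shape[2]
fc_params = dense_params(flat_features, units=10)
total_params = conv_params + fc_params

print("Conv2D output:", conv_shape, "params:", conv_params)   # (26, 26, 12), 120
print("MaxPooling2D output:", pool_shape)                     # (13, 13, 12)
print("Flatten output:", flat_features)                       # 2028
print("Dense params:", fc_params)                             # 20290
print("Total params:", total_params)                          # 20410
```

Almost all of the parameters are in the final `Dense` layer, so that layer accounts for most of the model's size.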

In [4]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
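
The training output reports 1688 steps per epoch. That number follows from the arguments to `model.fit`; here is a short worked calculation, assuming the 60,000-image MNIST training set and the Keras default batch size of 32 (no `batch_size` is passed in the cell above).

```python
import math

num_train_images = 60000   # size of the MNIST training set
validation_split = 0.1     # as passed to model.fit
batch_size = 32            # Keras default when batch_size is not specified

# Keras holds out the validation fraction and trains on the remainder.
num_validation = round(num_train_images * validation_split)
num_fit = num_train_images - num_validation

# The last batch of an epoch may be partial, hence the ceiling.
steps_per_epoch = math.ceil(num_fit / batch_size)

print("images used for training:", num_fit)       # 54000
print("steps per epoch:", steps_per_epoch)        # 1688
```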

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
    8192/11490434 [..............................] - ETA: 0s
2023-05-26 11:15:37.056973: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/10
   1/1688 [..............................] - ETA: 14:18 - loss: 2.3417 - accuracy: 0.0625
  12/1688 [..............................] - ETA: 7s - loss: 2.2253 - accuracy: 0.2760
  25/1688 [..............................] - ETA: 7s - loss: 2.0906 - accuracy: 0.4475
  38/1688 [..............................] - ETA: 6s - loss: 1.9241 - accuracy: 0.5469
  51/1688 [..............................] - ETA: 6s - loss: 1.7590 - accuracy: 0.6005
  64/1688 [>.............................] - ETA: 6s - loss: 1.5882 - accuracy: 0.6450
  77/1688 [>.............................] - ETA: 6s - loss: 1.4501 - accuracy: 0.6676
  90/1688 [>.............................] - ETA: 6s - loss: 1.3309 - accuracy: 0.6913
 103/1688 [>.............................] - ETA: 6s - loss: 1.2217 - accuracy: 0.7127
 117/1688 [=>............................] - ETA: 6s - loss: 1.1371 - accuracy: 0.7281
 131/1688 [=>............................] - ETA: 6s - loss: 1.0631 - accuracy: 0.7438
 145/1688 [=>............................] - ETA: 6s - loss: 1.0012 - accuracy: 0.7578
 159/1688 [=>............................] - ETA: 6s - loss: 0.9481 - accuracy: 0.7687
 173/1688 [==>...........................] - ETA: 5s - loss: 0.9036 - accuracy: 0.7794
 187/1688 [==>...........................] - ETA: 5s - loss: 0.8588 - accuracy: 0.7896
 201/1688 [==>...........................] - ETA: 5s - loss: 0.8285 - accuracy: 0.7948
 215/1688 [==>...........................] - ETA: 5s - loss: 0.8002 - accuracy: 0.8004
 229/1688 [===>..........................] - ETA: 5s - loss: 0.7718 - accuracy: 0.8064
 243/1688 [===>..........................] - ETA: 5s - loss: 0.7474 - accuracy: 0.8117
 257/1688 [===>..........................] - ETA: 5s - loss: 0.7232 - accuracy: 0.8164
 271/1688 [===>..........................] - ETA: 5s - loss: 0.7042 - accuracy: 0.8202
 284/1688 [====>.........................] - ETA: 5s - loss: 0.6871 - accuracy: 0.8241
 298/1688 [====>.........................] - ETA: 5s - loss: 0.6694 - accuracy: 0.8288
 312/1688 [====>.........................] - ETA: 5s - loss: 0.6542 - accuracy: 0.8317
 326/1688 [====>.........................] - ETA: 5s - loss: 0.6396 - accuracy: 0.8354
 340/1688 [=====>........................] - ETA: 5s - loss: 0.6272 - accuracy: 0.8384
 354/1688 [=====>........................] - ETA: 5s - loss: 0.6143 - accuracy: 0.8414
 368/1688 [=====>........................] - ETA: 5s - loss: 0.6008 - accuracy: 0.8446
 382/1688 [=====>........................] - ETA: 5s - loss: 0.5889 - accuracy: 0.8477
Epoch 2/10
   1/1688 [..............................] - ETA: 8s - loss: 0.2967 - accuracy: 0.9062
  14/1688 [..............................] - ETA: 6s - loss: 0.1691 - accuracy: 0.9420
  27/1688 [..............................] - ETA: 6s - loss: 0.1497 - accuracy: 0.9502
  40/1688 [..............................] - ETA: 6s - loss: 0.1307 - accuracy: 0.9586
  53/1688 [..............................] - ETA: 6s - loss: 0.1364 - accuracy: 0.9587
  66/1688 [>.............................] - ETA: 6s - loss: 0.1442 - accuracy: 0.9555
  79/1688 [>.............................] - ETA: 6s - loss: 0.1407 - accuracy: 0.9569
  92/1688 [>.............................] - ETA: 6s - loss: 0.1375 - accuracy: 0.9592
 105/1688 [>.............................] - ETA: 6s - loss: 0.1412 - accuracy: 0.9595
 118/1688 [=>............................] - ETA: 6s - loss: 0.1368 - accuracy: 0.9600
 132/1688 [=>............................] - ETA: 6s - loss: 0.1336 - accuracy: 0.9607
 145/1688 [=>............................] - ETA: 6s - loss: 0.1297 - accuracy: 0.9614
 158/1688 [=>............................] - ETA: 6s - loss: 0.1270 - accuracy: 0.9630
 171/1688 [==>...........................] - ETA: 5s - loss: 0.1291 - accuracy: 0.9624
 184/1688 [==>...........................] - ETA: 5s - loss: 0.1299 - accuracy: 0.9625
 197/1688 [==>...........................] - ETA: 5s - loss: 0.1282 - accuracy: 0.9632
 210/1688 [==>...........................] - ETA: 5s - loss: 0.1296 - accuracy: 0.9625
 223/1688 [==>...........................] - ETA: 5s - loss: 0.1304 - accuracy: 0.9624
 236/1688 [===>..........................] - ETA: 5s - loss: 0.1287 - accuracy: 0.9631
 249/1688 [===>..........................] - ETA: 5s - loss: 0.1290 - accuracy: 0.9631
 262/1688 [===>..........................] - ETA: 5s - loss: 0.1305 - accuracy: 0.9633
 276/1688 [===>..........................] - ETA: 5s - loss: 0.1310 - accuracy: 0.9633
 289/1688 [====>.........................] - ETA: 5s - loss: 0.1298 - accuracy: 0.9637
 302/1688 [====>.........................] - ETA: 5s - loss: 0.1286 - accuracy: 0.9639
 315/1688 [====>.........................] - ETA: 5s - loss: 0.1287 - accuracy: 0.9633
 328/1688 [====>.........................] - ETA: 5s - loss: 0.1284 - accuracy: 0.9631
 341/1688 [=====>........................] - ETA: 5s - loss: 0.1268 - accuracy: 0.9635
 354/1688 [=====>........................] - ETA: 5s - loss: 0.1261 - accuracy: 0.9641
 367/1688 [=====>........................] - ETA: 5s - loss: 0.1259 - accuracy: 0.9642
 381/1688 [=====>........................] - ETA: 5s - loss: 0.1248 - accuracy: 0.9646
Epoch 3/10
   1/1688 [..............................] - ETA: 8s - loss: 0.0561 - accuracy: 0.9688
  15/1688 [..............................] - ETA: 6s - loss: 0.1431 - accuracy: 0.9583
  29/1688 [..............................] - ETA: 6s - loss: 0.1056 - accuracy: 0.9720
  42/1688 [..............................] - ETA: 6s - loss: 0.0980 - accuracy: 0.9732
  55/1688 [..............................] - ETA: 6s - loss: 0.0910 - accuracy: 0.9750
  68/1688 [>.............................] - ETA: 6s - loss: 0.0817 - accuracy: 0.9784
  82/1688 [>.............................] - ETA: 6s - loss: 0.0757 - accuracy: 0.9802
  95/1688 [>.............................] - ETA: 6s - loss: 0.0792 - accuracy: 0.9786
 109/1688 [>.............................] - ETA: 6s - loss: 0.0773 - accuracy: 0.9796
 122/1688 [=>............................] - ETA: 6s - loss: 0.0787 - accuracy: 0.9795
 136/1688 [=>............................] - ETA: 5s - loss: 0.0801 - accuracy: 0.9782
 149/1688 [=>............................] - ETA: 5s - loss: 0.0793 - accuracy: 0.9776
 162/1688 [=>............................] - ETA: 5s - loss: 0.0813 - accuracy: 0.9769
 175/1688 [==>...........................] - ETA: 5s - loss: 0.0844 - accuracy: 0.9761
 189/1688 [==>...........................] - ETA: 5s - loss: 0.0827 - accuracy: 0.9764
 202/1688 [==>...........................] - ETA: 5s - loss: 0.0864 - accuracy: 0.9751
 216/1688 [==>...........................] - ETA: 5s - loss: 0.0880 - accuracy: 0.9748
 229/1688 [===>..........................] - ETA: 5s - loss: 0.0866 - accuracy: 0.9752
 242/1688 [===>..........................] - ETA: 5s - loss: 0.0878 - accuracy: 0.9751
 255/1688 [===>..........................] - ETA: 5s - loss: 0.0865 - accuracy: 0.9755
 269/1688 [===>..........................] - ETA: 5s - loss: 0.0855 - accuracy: 0.9756


 282/1688 [====>.........................] - ETA: 5s - loss: 0.0842 - accuracy: 0.9764


 295/1688 [====>.........................] - ETA: 5s - loss: 0.0828 - accuracy: 0.9770


 309/1688 [====>.........................] - ETA: 5s - loss: 0.0842 - accuracy: 0.9769


 323/1688 [====>.........................] - ETA: 5s - loss: 0.0832 - accuracy: 0.9771


 336/1688 [====>.........................] - ETA: 5s - loss: 0.0828 - accuracy: 0.9773


 349/1688 [=====>........................] - ETA: 5s - loss: 0.0825 - accuracy: 0.9773


 362/1688 [=====>........................] - ETA: 5s - loss: 0.0821 - accuracy: 0.9777


 375/1688 [=====>........................] - ETA: 5s - loss: 0.0823 - accuracy: 0.9774


 389/1688 [=====>........................] - ETA: 5s - loss: 0.0819 - accuracy: 0.9775

Epoch 4/10



   1/1688 [..............................] - ETA: 8s - loss: 0.0508 - accuracy: 1.0000


  14/1688 [..............................] - ETA: 6s - loss: 0.0358 - accuracy: 0.9978


  28/1688 [..............................] - ETA: 6s - loss: 0.0524 - accuracy: 0.9866


  42/1688 [..............................] - ETA: 6s - loss: 0.0600 - accuracy: 0.9844


  56/1688 [..............................] - ETA: 6s - loss: 0.0552 - accuracy: 0.9849


  70/1688 [>.............................] - ETA: 6s - loss: 0.0626 - accuracy: 0.9826


  84/1688 [>.............................] - ETA: 6s - loss: 0.0619 - accuracy: 0.9833


  98/1688 [>.............................] - ETA: 6s - loss: 0.0669 - accuracy: 0.9812


 112/1688 [>.............................] - ETA: 5s - loss: 0.0663 - accuracy: 0.9807


 126/1688 [=>............................] - ETA: 5s - loss: 0.0691 - accuracy: 0.9802


 140/1688 [=>............................] - ETA: 5s - loss: 0.0677 - accuracy: 0.9804


 154/1688 [=>............................] - ETA: 5s - loss: 0.0676 - accuracy: 0.9807


 168/1688 [=>............................] - ETA: 5s - loss: 0.0684 - accuracy: 0.9808


 182/1688 [==>...........................] - ETA: 5s - loss: 0.0681 - accuracy: 0.9808


 196/1688 [==>...........................] - ETA: 5s - loss: 0.0683 - accuracy: 0.9809


 210/1688 [==>...........................] - ETA: 5s - loss: 0.0686 - accuracy: 0.9805


 224/1688 [==>...........................] - ETA: 5s - loss: 0.0673 - accuracy: 0.9807


 238/1688 [===>..........................] - ETA: 5s - loss: 0.0656 - accuracy: 0.9810


 252/1688 [===>..........................] - ETA: 5s - loss: 0.0651 - accuracy: 0.9812


 266/1688 [===>..........................] - ETA: 5s - loss: 0.0644 - accuracy: 0.9814


 280/1688 [===>..........................] - ETA: 5s - loss: 0.0656 - accuracy: 0.9810


 294/1688 [====>.........................] - ETA: 5s - loss: 0.0648 - accuracy: 0.9811


 308/1688 [====>.........................] - ETA: 5s - loss: 0.0644 - accuracy: 0.9810


 322/1688 [====>.........................] - ETA: 5s - loss: 0.0642 - accuracy: 0.9812


 336/1688 [====>.........................] - ETA: 5s - loss: 0.0644 - accuracy: 0.9809


 350/1688 [=====>........................] - ETA: 5s - loss: 0.0646 - accuracy: 0.9807


 364/1688 [=====>........................] - ETA: 4s - loss: 0.0638 - accuracy: 0.9811


 378/1688 [=====>........................] - ETA: 4s - loss: 0.0647 - accuracy: 0.9812


 392/1688 [=====>........................] - ETA: 4s - loss: 0.0657 - accuracy: 0.9812

Epoch 5/10



   1/1688 [..............................] - ETA: 7s - loss: 0.0564 - accuracy: 0.9688


  15/1688 [..............................] - ETA: 6s - loss: 0.0394 - accuracy: 0.9896


  29/1688 [..............................] - ETA: 6s - loss: 0.0756 - accuracy: 0.9828


  43/1688 [..............................] - ETA: 6s - loss: 0.0694 - accuracy: 0.9833


  57/1688 [>.............................] - ETA: 6s - loss: 0.0697 - accuracy: 0.9819


  71/1688 [>.............................] - ETA: 6s - loss: 0.0683 - accuracy: 0.9802


  85/1688 [>.............................] - ETA: 5s - loss: 0.0644 - accuracy: 0.9816


  99/1688 [>.............................] - ETA: 5s - loss: 0.0646 - accuracy: 0.9817


 113/1688 [=>............................] - ETA: 5s - loss: 0.0654 - accuracy: 0.9820


 127/1688 [=>............................] - ETA: 5s - loss: 0.0634 - accuracy: 0.9828


 141/1688 [=>............................] - ETA: 5s - loss: 0.0610 - accuracy: 0.9836


 155/1688 [=>............................] - ETA: 5s - loss: 0.0594 - accuracy: 0.9841


 169/1688 [==>...........................] - ETA: 5s - loss: 0.0585 - accuracy: 0.9841


 183/1688 [==>...........................] - ETA: 5s - loss: 0.0581 - accuracy: 0.9843


 196/1688 [==>...........................] - ETA: 5s - loss: 0.0589 - accuracy: 0.9842


 210/1688 [==>...........................] - ETA: 5s - loss: 0.0587 - accuracy: 0.9841


 224/1688 [==>...........................] - ETA: 5s - loss: 0.0590 - accuracy: 0.9840


 238/1688 [===>..........................] - ETA: 5s - loss: 0.0578 - accuracy: 0.9844


 252/1688 [===>..........................] - ETA: 5s - loss: 0.0582 - accuracy: 0.9843


 266/1688 [===>..........................] - ETA: 5s - loss: 0.0583 - accuracy: 0.9839


 280/1688 [===>..........................] - ETA: 5s - loss: 0.0582 - accuracy: 0.9839


 294/1688 [====>.........................] - ETA: 5s - loss: 0.0602 - accuracy: 0.9832


 308/1688 [====>.........................] - ETA: 5s - loss: 0.0618 - accuracy: 0.9827


 322/1688 [====>.........................] - ETA: 5s - loss: 0.0617 - accuracy: 0.9826


 336/1688 [====>.........................] - ETA: 5s - loss: 0.0618 - accuracy: 0.9826


 350/1688 [=====>........................] - ETA: 5s - loss: 0.0617 - accuracy: 0.9828


 364/1688 [=====>........................] - ETA: 4s - loss: 0.0609 - accuracy: 0.9830


 378/1688 [=====>........................] - ETA: 4s - loss: 0.0605 - accuracy: 0.9831


 392/1688 [=====>........................] - ETA: 4s - loss: 0.0609 - accuracy: 0.9831

Epoch 6/10



   1/1688 [..............................] - ETA: 7s - loss: 0.0339 - accuracy: 1.0000


  14/1688 [..............................] - ETA: 6s - loss: 0.0508 - accuracy: 0.9911


  27/1688 [..............................] - ETA: 6s - loss: 0.0631 - accuracy: 0.9861


  41/1688 [..............................] - ETA: 6s - loss: 0.0520 - accuracy: 0.9878


  55/1688 [..............................] - ETA: 6s - loss: 0.0505 - accuracy: 0.9875


  69/1688 [>.............................] - ETA: 6s - loss: 0.0489 - accuracy: 0.9869


  83/1688 [>.............................] - ETA: 6s - loss: 0.0480 - accuracy: 0.9868


  97/1688 [>.............................] - ETA: 6s - loss: 0.0475 - accuracy: 0.9871


 111/1688 [>.............................] - ETA: 6s - loss: 0.0463 - accuracy: 0.9868


 125/1688 [=>............................] - ETA: 5s - loss: 0.0447 - accuracy: 0.9870


 139/1688 [=>............................] - ETA: 5s - loss: 0.0447 - accuracy: 0.9867


 153/1688 [=>............................] - ETA: 5s - loss: 0.0435 - accuracy: 0.9867


 167/1688 [=>............................] - ETA: 5s - loss: 0.0443 - accuracy: 0.9862


 181/1688 [==>...........................] - ETA: 5s - loss: 0.0443 - accuracy: 0.9860


 195/1688 [==>...........................] - ETA: 5s - loss: 0.0430 - accuracy: 0.9865


 209/1688 [==>...........................] - ETA: 5s - loss: 0.0444 - accuracy: 0.9858


 223/1688 [==>...........................] - ETA: 5s - loss: 0.0456 - accuracy: 0.9850


 237/1688 [===>..........................] - ETA: 5s - loss: 0.0454 - accuracy: 0.9852


 251/1688 [===>..........................] - ETA: 5s - loss: 0.0452 - accuracy: 0.9853


 265/1688 [===>..........................] - ETA: 5s - loss: 0.0459 - accuracy: 0.9854


 278/1688 [===>..........................] - ETA: 5s - loss: 0.0454 - accuracy: 0.9856


 291/1688 [====>.........................] - ETA: 5s - loss: 0.0461 - accuracy: 0.9855


 305/1688 [====>.........................] - ETA: 5s - loss: 0.0476 - accuracy: 0.9852


 319/1688 [====>.........................] - ETA: 5s - loss: 0.0480 - accuracy: 0.9845


 333/1688 [====>.........................] - ETA: 5s - loss: 0.0479 - accuracy: 0.9845


 347/1688 [=====>........................] - ETA: 5s - loss: 0.0474 - accuracy: 0.9847


 361/1688 [=====>........................] - ETA: 5s - loss: 0.0482 - accuracy: 0.9843


 375/1688 [=====>........................] - ETA: 4s - loss: 0.0494 - accuracy: 0.9842


 389/1688 [=====>........................] - ETA: 4s - loss: 0.0501 - accuracy: 0.9839

Epoch 7/10



   1/1688 [..............................] - ETA: 7s - loss: 0.0348 - accuracy: 0.9688


  15/1688 [..............................] - ETA: 6s - loss: 0.0294 - accuracy: 0.9917


  29/1688 [..............................] - ETA: 6s - loss: 0.0324 - accuracy: 0.9925


  43/1688 [..............................] - ETA: 6s - loss: 0.0384 - accuracy: 0.9906


  57/1688 [>.............................] - ETA: 6s - loss: 0.0369 - accuracy: 0.9907


  71/1688 [>.............................] - ETA: 6s - loss: 0.0434 - accuracy: 0.9894


  85/1688 [>.............................] - ETA: 6s - loss: 0.0448 - accuracy: 0.9890


  99/1688 [>.............................] - ETA: 5s - loss: 0.0488 - accuracy: 0.9877


 113/1688 [=>............................] - ETA: 5s - loss: 0.0491 - accuracy: 0.9873


 127/1688 [=>............................] - ETA: 5s - loss: 0.0499 - accuracy: 0.9865


 141/1688 [=>............................] - ETA: 5s - loss: 0.0505 - accuracy: 0.9860


 155/1688 [=>............................] - ETA: 5s - loss: 0.0511 - accuracy: 0.9855


 169/1688 [==>...........................] - ETA: 5s - loss: 0.0489 - accuracy: 0.9861


 183/1688 [==>...........................] - ETA: 5s - loss: 0.0483 - accuracy: 0.9862


 197/1688 [==>...........................] - ETA: 5s - loss: 0.0472 - accuracy: 0.9862


 211/1688 [==>...........................] - ETA: 5s - loss: 0.0461 - accuracy: 0.9865


 225/1688 [==>...........................] - ETA: 5s - loss: 0.0458 - accuracy: 0.9865


 239/1688 [===>..........................] - ETA: 5s - loss: 0.0455 - accuracy: 0.9868


 253/1688 [===>..........................] - ETA: 5s - loss: 0.0443 - accuracy: 0.9872


 267/1688 [===>..........................] - ETA: 5s - loss: 0.0448 - accuracy: 0.9871


 281/1688 [===>..........................] - ETA: 5s - loss: 0.0447 - accuracy: 0.9873


 295/1688 [====>.........................] - ETA: 5s - loss: 0.0442 - accuracy: 0.9874


 309/1688 [====>.........................] - ETA: 5s - loss: 0.0445 - accuracy: 0.9870


 323/1688 [====>.........................] - ETA: 5s - loss: 0.0438 - accuracy: 0.9875


 337/1688 [====>.........................] - ETA: 5s - loss: 0.0428 - accuracy: 0.9879


 351/1688 [=====>........................] - ETA: 5s - loss: 0.0424 - accuracy: 0.9879


 365/1688 [=====>........................] - ETA: 4s - loss: 0.0417 - accuracy: 0.9882


 379/1688 [=====>........................] - ETA: 4s - loss: 0.0430 - accuracy: 0.9878


 393/1688 [=====>........................] - ETA: 4s - loss: 0.0439 - accuracy: 0.9876

Epoch 8/10
 392/1688 [=====>........................] - ETA: 4s - loss: 0.0395 - accuracy: 0.9885

Epoch 9/10
 391/1688 [=====>........................] - ETA: 4s - loss: 0.0373 - accuracy: 0.9891

Epoch 10/10
 392/1688 [=====>........................] - ETA: 4s - loss: 0.0310 - accuracy: 0.9922







































































































































































<keras.src.callbacks.History at 0x7fa7f4242100>

### Evaluate the baseline model and save it for later usage

In [5]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.9790999889373779
Saving model to:  /tmpfs/tmp/tmp9jq1ksoy.h5



## Fine-tune the pre-trained model with clustering

Apply the `cluster_weights()` API to the whole pre-trained model to demonstrate how effectively it reduces the zipped model size while keeping decent accuracy. To learn how best to balance accuracy and compression rate for your use case, refer to the per-layer example in the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/clustering/clustering_comprehensive_guide).
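
To make the idea concrete, here is a minimal pure-Python sketch of what clustering does to a single weight tensor. Centroids start evenly spaced between the smallest and largest weight, which is the intent of linear centroid initialization, and every weight is replaced by its nearest centroid. The helper names `linear_centroids` and `cluster` are hypothetical, and the random weights are a stand-in for a real layer. This is not the toolkit's implementation, which also fine-tunes the centroids during training.

```python
import random


def linear_centroids(weights, number_of_clusters):
    """Evenly space centroids between the smallest and largest weight."""
    lo, hi = min(weights), max(weights)
    step = (hi - lo) / (number_of_clusters - 1)
    return [lo + i * step for i in range(number_of_clusters)]


def cluster(weights, centroids):
    """Replace every weight with its nearest centroid."""
    return [min(centroids, key=lambda c: abs(c - w)) for w in weights]


random.seed(0)
# Hypothetical stand-in for the weights of one layer.
weights = [random.gauss(0.0, 0.05) for _ in range(1000)]

centroids = linear_centroids(weights, number_of_clusters=16)
clustered = cluster(weights, centroids)

print("distinct values before:", len(set(weights)))
print("distinct values after: ", len(set(clustered)))
```

After clustering, the tensor holds at most 16 distinct values, which is what makes it compress well later on.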


### Define the model and apply the clustering API

Before you pass the model to the clustering API, make sure it is trained and reaches acceptable accuracy.

In [6]:
import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 16,
  'cluster_centroids_init': CentroidInitialization.LINEAR
}

# Cluster a whole model
clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning clustered model
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

clustered_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 cluster_reshape (ClusterWe  (None, 28, 28, 1)         0         
 ights)                                                          

 cluster_conv2d (ClusterWei  (None, 26, 26, 12)        244       
 ghts)                                                           

 cluster_max_pooling2d (Clu  (None, 13, 13, 12)        0         
 sterWeights)                                                    

 cluster_flatten (ClusterWe  (None, 2028)              0         
 ights)                                                          

 cluster_dense (ClusterWeig  (None, 10)                40586     
 hts)                                                            

=================================================================
Total params: 40830 (239.13 KB)
Trainable params: 20442 (79.85 KB)
Non-trainable params: 20388 (159.28 KB)
_________________________________________________________________
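
The parameter counts in this summary can be reproduced with a little arithmetic. The sketch below assumes that each clustered layer keeps its original weights and adds one non-trainable index per kernel weight plus one trainable centroid per cluster. That split is an inference that happens to match the printed totals, not something the summary states. The 3x3 kernel size is likewise inferred from the 28 to 26 change in output shape.

```python
number_of_clusters = 16

# Original layer weights (3x3 kernel inferred from the output shapes).
conv_kernel = 3 * 3 * 1 * 12   # 108
conv_bias = 12
dense_kernel = 2028 * 10       # 20280
dense_bias = 10
original = conv_kernel + conv_bias + dense_kernel + dense_bias

# Assumed extra variables added by the clustering wrappers.
indices = conv_kernel + dense_kernel   # one index per kernel weight
centroids = 2 * number_of_clusters     # 16 centroids for each of 2 layers

conv_total = conv_kernel + conv_bias + conv_kernel + number_of_clusters
dense_total = dense_kernel + dense_bias + dense_kernel + number_of_clusters

trainable = original + centroids
non_trainable = indices
total = trainable + non_trainable

print("cluster_conv2d params:", conv_total)    # 244
print("cluster_dense params: ", dense_total)   # 40586
print("trainable:", trainable)                 # 20442
print("non-trainable:", non_trainable)         # 20388
print("total:", total)                         # 40830
```

Under this reading, roughly half of the parameters in the wrapped model exist only to support training.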


### Fine-tune the model and evaluate the accuracy against baseline

Fine-tune the model with clustering for 1 epoch.

In [7]:
# Fine-tune model
clustered_model.fit(
  train_images,
  train_labels,
  batch_size=500,
  epochs=1,
  validation_split=0.1)


<keras.src.callbacks.History at 0x7fa7f47d8e20>

In this run, test accuracy drops only slightly after clustering: about 0.9780, compared with about 0.9791 for the baseline.

In [8]:
_, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)

Baseline test accuracy: 0.9790999889373779
Clustered test accuracy: 0.9779999852180481


## Create **6x** smaller models from clustering

Both `strip_clustering` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of clustering. 
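
The effect of compression on clustered weights can be illustrated without TensorFlow. The following sketch packs random weights as float32, snaps a copy of them to 16 evenly spaced values as a stand-in for clustering, and compares the DEFLATE-compressed sizes. The weights, the snapping rule and the helper `deflated_size` are illustrative assumptions, so the numbers will not match the model sizes in this tutorial.

```python
import random
import struct
import zlib

random.seed(0)
# Hypothetical stand-in for a layer's weights.
weights = [random.gauss(0.0, 0.05) for _ in range(20000)]

# Snap each weight to one of 16 evenly spaced values.
lo, hi = min(weights), max(weights)
step = (hi - lo) / 15
clustered = [lo + round((w - lo) / step) * step for w in weights]


def deflated_size(values):
    """Size in bytes of the float32-packed values after DEFLATE compression."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw, 9))


baseline_size = deflated_size(weights)
clustered_size = deflated_size(clustered)

print("raw float32 bytes:     ", 4 * len(weights))
print("deflated, unclustered: ", baseline_size)
print("deflated, clustered:   ", clustered_size)
```

Both arrays occupy the same number of bytes before compression. The saving only appears once a compressor can exploit the repeated values, which is why the compression step is required.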

First, create a compressible model for TensorFlow. Here, `strip_clustering` removes all variables (e.g. `tf.Variable` for storing the cluster centroids and the indices) that clustering only needs during training, which would otherwise add to model size during inference.

In [9]:
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

_, clustered_keras_file = tempfile.mkstemp('.h5')
print('Saving clustered model to: ', clustered_keras_file)
tf.keras.models.save_model(final_model, clustered_keras_file, 
                           include_optimizer=False)

Saving clustered model to:  /tmpfs/tmp/tmpx4u1ju_5.h5



Then, create compressible models for TFLite. You can convert the clustered model to a format that's runnable on your targeted backend. TensorFlow Lite is an example you can use to deploy to mobile devices.

In [10]:
clustered_tflite_file = '/tmp/clustered_mnist.tflite'
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_clustered_model = converter.convert()
with open(clustered_tflite_file, 'wb') as f:
  f.write(tflite_clustered_model)
print('Saved clustered TFLite model to:', clustered_tflite_file)

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpt21y56f3/assets



2023-05-26 11:16:50.733209: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-05-26 11:16:50.733245: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.


Saved clustered TFLite model to: /tmp/clustered_mnist.tflite


Define a helper function to actually compress the models via gzip and measure the zipped size.

In [11]:
def get_gzipped_model_size(file):
  # Returns the size in bytes of the model file after zip (DEFLATE) compression.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

Compare the sizes to see that the models are about **6x** smaller after clustering.

In [12]:
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped clustered Keras model: %.2f bytes" % (get_gzipped_model_size(clustered_keras_file)))
print("Size of gzipped clustered TFlite model: %.2f bytes" % (get_gzipped_model_size(clustered_tflite_file)))

Size of gzipped baseline Keras model: 78169.00 bytes
Size of gzipped clustered Keras model: 12649.00 bytes
Size of gzipped clustered TFlite model: 12230.00 bytes


## Create an **8x** smaller TFLite model from combining weight clustering and post-training quantization

You can apply post-training quantization to the clustered model for additional benefits.
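
As a rough illustration of why quantization shrinks the file further, the sketch below maps float weights to 8-bit integer codes plus a single float scale, so each stored weight takes one byte instead of four. It assumes a simplified symmetric per-tensor scheme. The helper names and sample weights are hypothetical, and the TFLite converter's actual quantization differs in its details.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization: float -> int8 code plus one scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate float values from the integer codes."""
    return [c * scale for c in codes]


# Hypothetical clustered weights: only a few distinct values.
weights = [-0.31, -0.12, 0.0, 0.07, 0.25, 0.31]
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)

print("int8 codes:", codes)
print("max round-trip error:",
      max(abs(a - b) for a, b in zip(weights, restored)))
```

Because a clustered tensor has few distinct values, its quantized form also has few distinct codes and stays highly compressible.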

In [13]:
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

_, quantized_and_clustered_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_clustered_tflite_file, 'wb') as f:
  f.write(tflite_quant_model)

print('Saved quantized and clustered TFLite model to:', quantized_and_clustered_tflite_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped clustered and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_clustered_tflite_file)))

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpuk3uupzj/assets



2023-05-26 11:16:51.350625: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-05-26 11:16:51.350663: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.


Saved quantized and clustered TFLite model to: /tmpfs/tmp/tmpnwe0hp0u.tflite
Size of gzipped baseline Keras model: 78169.00 bytes
Size of gzipped clustered and quantized TFlite model: 9466.00 bytes
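
The 6x and 8x figures follow directly from the sizes printed in this run. The short calculation below uses those byte counts as constants; a different run will give slightly different values.

```python
baseline = 78169  # gzipped baseline Keras model, in bytes

compressed_sizes = {
    "clustered Keras": 12649,
    "clustered TFLite": 12230,
    "clustered and quantized TFLite": 9466,
}

ratios = {name: baseline / size for name, size in compressed_sizes.items()}
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}x smaller than baseline")
```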


## See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

In [14]:
def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the clustered and quantized model, and see that the accuracy from TensorFlow persists to the TFLite backend.

In [15]:
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()

test_accuracy = eval_model(interpreter)

print('Clustered and quantized TFLite test_accuracy:', test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Clustered and quantized TFLite test_accuracy: 0.9785
Clustered TF test accuracy: 0.9779999852180481


## Conclusion

In this tutorial, you saw how to create clustered models with the TensorFlow Model Optimization Toolkit API. More specifically, you've been through an end-to-end example for creating an 8x smaller model for MNIST with minimal accuracy difference. We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.
