**Copyright 2021 The TensorFlow Authors.**

In [1]:
# The Ultimate compression Pipeline  ~ Ajay Maheshwari ( LCI2021023 ) under Dr. Mainak Adhikari

# Step 1.  Creating our Base Model  

In [30]:
! pip install -q tensorflow-model-optimization

In [31]:
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

In [32]:
print(tf.__version__)

2.16.1


In [33]:
def get_gzipped_model_size(model):
  # Save the model to a temporary file
  with tempfile.NamedTemporaryFile(suffix=".h5") as temp_file:  # Adjust suffix based on model format
    model.save(temp_file.name)

    # Create a zip archive and write the temporary file
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(temp_file.name)

    return os.path.getsize(zipped_file) / 1000

In [34]:
# MNIST dataset + Model Creation 

mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images  = test_images / 255.0


model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    
    # First Convolutional Block
    keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=32, kernel_size=(4, 4), activation='relu', padding='same'),
    keras.layers.MaxPooling2D(pool_size=(3, 3)),
    
    # Second Convolutional Block
    keras.layers.Conv2D(filters=64, kernel_size=(5, 5), activation='tanh', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation='tanh', padding='same'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    
    # Third Convolutional Block
    keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation='selu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation='selu', padding='same'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    
    keras.layers.Flatten(),
    
    # Fully Connected Layers
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dropout(0.4),  # Dropout for regularization
    
    keras.layers.Dense(128, activation='tanh'),
    keras.layers.Dropout(0.4),  # Dropout for regularization
    
    keras.layers.Dense(64, activation='selu'),
    keras.layers.Dropout(0.4),  # Dropout for regularization
    
    keras.layers.Dense(10)
])

opt = keras.optimizers.Adam(learning_rate=1e-3)



In [35]:
# Training the Model ( 10 Epochs for now )

model.compile(optimizer=opt,
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=3
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tf_keras.src.callbacks.History at 0x282fef490>

In [36]:

# data_augmentation = keras.Sequential([
#     keras.layers.RandomFlip("horizontal_and_vertical"),
#     keras.layers.RandomRotation(0.2),
#     keras.layers.RandomZoom(0.2),
# #     layers.RandomContrast(0.2),
# #     layers.Lambda(lambda x: tf.image.random_brightness(x, max_delta=0.2)),  # Adding random brightness
# #     layers.Lambda(lambda x: tf.image.random_saturation(x, lower=0.8, upper=1.2)),  # Adding random saturation
# #     layers.Lambda(lambda x: tf.image.random_hue(x, max_delta=0.2)),  # Adding random hue
# #     layers.Lambda(lambda x: tf.image.random_jpeg_quality(x, min_jpeg_quality=80, max_jpeg_quality=100))  # Random JPEG quality
# ])

# # Apply data augmentation to the test images
# augmented_images = data_augmentation(test_images, training=False)

# # Combine the original and augmented test images and labels
# augmented_test_images = np.concatenate([test_images, augmented_images], axis=0)
# augmented_test_labels = np.concatenate([test_labels, test_labels], axis=0)  # Duplicate labels for augmented images

# print('Original test images shape:', test_images.shape)
# print('Augmented test images shape:', augmented_test_images.shape)

### Evaluating the base model 

In [37]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Base Model test accuracy:', baseline_model_accuracy)

print("Base model size: ",  get_gzipped_model_size(model) , ' KB' )


Base Model test accuracy: 0.9704999923706055


  saving_api.save_model(


Base model size:  6960.145  KB


In [38]:
# import tensorflow as tf
# import numpy as np
# import matplotlib.pyplot as plt

# # Initialize dictionary to store weights and their counts
# weight_counts = {}

# # Iterate through layers
# for layer in model.layers:
#     if hasattr(layer, 'get_weights'):
#         layer_weights = layer.get_weights()
#         for w in layer_weights:
#             # Flatten weights if necessary
#             w_flat = w.flatten()
#             # Update dictionary
#             for weight in w_flat:
#                 if weight in weight_counts:
#                     weight_counts[weight] += 1
#                 else:
#                     weight_counts[weight] = 1

# # Extract unique weights and counts
# unique_weights = np.array(list(weight_counts.keys()))
# counts = np.array(list(weight_counts.values()))

# # Plot the graph
# plt.figure(figsize=(10, 6))
# plt.bar(unique_weights, counts, width=0.1)
# plt.xlabel('Weights')
# plt.ylabel('Number of Connections')
# plt.title('Weights vs Number of Connections')
# plt.show()


# Step 2. Pruning

In [39]:
# sparsity_values = []
# accuracies = []

# for sparsity in np.arange(0.1, 1.0, 0.1):
#     prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

#     pruning_params = {
#           'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(sparsity, begin_step=0, frequency=100)
#       }

#     callbacks = [
#       tfmot.sparsity.keras.UpdatePruningStep()
#     ]

#     pruned_model = prune_low_magnitude(model, **pruning_params)

#     opt = keras.optimizers.Adam(learning_rate=1e-5)

#     pruned_model.compile(
#       loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#       optimizer=opt,
#       metrics=['accuracy'])


#     pruned_model.fit(
#       train_images,
#       train_labels,
#       epochs=3,
#       validation_split=0.1,
#       callbacks=callbacks)
    
    
#     _, accuracy = pruned_model.evaluate(test_images, test_labels, verbose=0)
    

#     sparsity_values.append(sparsity)
#     accuracies.append(accuracy)

# import matplotlib.pyplot as plt

# plt.plot(sparsity_values, accuracies, marker='o')
# plt.xlabel('Sparsity')
# plt.ylabel('Accuracy')
# plt.title('Accuracy vs Sparsity')
# plt.grid(True)
# plt.show()


In [40]:
# Gnerally it starts dropping for pruning > 60% connections

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.7, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])



In [41]:
# Re-training Model so that remaining connections learn again 

pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tf_keras.src.callbacks.History at 0x2929adf10>

In [42]:
# _, pruned_model_accuracy = pruned_model.evaluate(
#     test_images, test_labels, verbose=0)

# print('Pruned Model test accuracy:', pruned_model_accuracy)


### Checking if actually pruned or not 

In [43]:
def print_model_weights_sparsity(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )


stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

conv2d_10/kernel:0: 69.97% sparsity  (403/576)
conv2d_11/kernel:0: 70.00% sparsity  (22938/32768)
conv2d_12/kernel:0: 70.00% sparsity  (35840/51200)
conv2d_13/kernel:0: 70.00% sparsity  (51610/73728)
conv2d_14/kernel:0: 70.00% sparsity  (103219/147456)
conv2d_15/kernel:0: 70.00% sparsity  (103219/147456)
dense_11/kernel:0: 70.00% sparsity  (91750/131072)
dense_12/kernel:0: 70.00% sparsity  (22938/32768)
dense_13/kernel:0: 70.00% sparsity  (5734/8192)
dense_14/kernel:0: 70.00% sparsity  (448/640)


In [44]:
stripped_pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

# stripped_pruned_model.fit(train_images,
#   train_labels,
#   epochs=3,
#   validation_split=0.1,
#   callbacks=callbacks)

# _, pruned_model_accuracy = stripped_pruned_model.evaluate(
#     test_images, test_labels, verbose=0)

# print('Pruned Model test accuracy:', pruned_model_accuracy)



In [45]:
_, pruned_model_accuracy = stripped_pruned_model.evaluate(
    test_images, test_labels, verbose=0)

print('Pruned Model test accuracy:', pruned_model_accuracy)

print("Pruned model size: ",  get_gzipped_model_size(stripped_pruned_model) , ' KB' )



Pruned Model test accuracy: 0.9882000088691711
Pruned model size:  5588.956  KB


In [46]:
# import tensorflow as tf
# import numpy as np
# import matplotlib.pyplot as plt

# # Initialize dictionary to store weights and their counts
# weight_counts = {}

# # Iterate through layers
# for layer in stripped_pruned_model.layers:
#     if hasattr(layer, 'get_weights'):
#         layer_weights = layer.get_weights()
#         for w in layer_weights:
#             # Flatten weights if necessary
#             w_flat = w.flatten()
#             # Update dictionary
#             for weight in w_flat:
#                 if weight in weight_counts:
#                     weight_counts[weight] += 1
#                 else:
#                     weight_counts[weight] = 1

# # Extract unique weights and counts
# unique_weights = np.array(list(weight_counts.keys()))
# counts = np.array(list(weight_counts.values()))

# # Plot the graph
# plt.figure(figsize=(10, 6))
# plt.bar(unique_weights, counts, width=0.1)
# plt.xlabel('Weights')
# plt.ylabel('Number of Connections')
# plt.title('Weights vs Number of Connections')
# plt.show()


In [47]:
# -----------------------------  Step 2 Pruning Done ------------------------------

# Step 3. Weight Clustering

In [48]:
def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

In [49]:
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels,epochs=2, validation_split=0.1)

Train sparsity preserving clustering model:
Epoch 1/2
Epoch 2/2


<tf_keras.src.callbacks.History at 0x2854b2310>

In [50]:
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)



Model sparsity:

kernel:0: 71.88% sparsity  (414/576)
kernel:0: 73.65% sparsity  (24134/32768)
kernel:0: 75.97% sparsity  (38899/51200)
kernel:0: 74.99% sparsity  (55286/73728)
kernel:0: 75.52% sparsity  (111354/147456)
kernel:0: 75.66% sparsity  (111561/147456)
kernel:0: 74.47% sparsity  (97607/131072)
kernel:0: 73.85% sparsity  (24199/32768)
kernel:0: 72.88% sparsity  (5970/8192)
kernel:0: 70.62% sparsity  (452/640)

Model clusters:

conv2d_10/kernel:0: 8 clusters 
conv2d_11/kernel:0: 8 clusters 
conv2d_12/kernel:0: 8 clusters 
conv2d_13/kernel:0: 8 clusters 
conv2d_14/kernel:0: 8 clusters 
conv2d_15/kernel:0: 8 clusters 
dense_11/kernel:0: 8 clusters 
dense_12/kernel:0: 8 clusters 
dense_13/kernel:0: 8 clusters 
dense_14/kernel:0: 8 clusters 


In [51]:

_, sparsity_clustered_model_accuracy = sparsity_clustered_model.evaluate(test_images, test_labels, verbose=0)

print('Clustered Model test accuracy:', sparsity_clustered_model_accuracy)
print("Clustered model size: ",  get_gzipped_model_size(sparsity_clustered_model) , ' KB' )


Clustered Model test accuracy: 0.9872000217437744
Clustered model size:  3094.786  KB


In [52]:
# Step 3.) Distillation 

In [53]:
from keras import layers
from keras import ops
import numpy as np

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)

In [54]:
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(28, kernel_size=(5, 5), padding='same'),
        keras.layers.Activation('relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(28, kernel_size=(5, 5)),
        keras.layers.Activation('relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.25),
        keras.layers.Conv2D(32, kernel_size=(5, 5), padding='same'),
        keras.layers.Activation('relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(32, kernel_size=(5, 5)),
        keras.layers.Activation('relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.25),
        keras.layers.Flatten(),
        keras.layers.Dense(512),
        keras.layers.Activation('relu'),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(10),
        keras.layers.Activation('softmax')
    ],
    name="teacher",
)


teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)



In [55]:
teacher.fit(train_images, train_labels, epochs=3)
teacher.evaluate(test_images, test_labels)

Epoch 1/3


  output, from_logits = _get_logits(


Epoch 2/3
Epoch 3/3


[0.030768228694796562, 0.9904999732971191]

In [57]:
distiller = Distiller(student=stripped_clustered_model, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)




In [58]:
# -------------------------------- Pruning + Clustering + Distillation Done .... ----------------------------------------

In [59]:
# Distill teacher se student
distiller.fit(train_images, train_labels, epochs=3)

distiller.evaluate(test_images, test_labels)

Epoch 1/3
Epoch 2/3
Epoch 3/3


0.9858999848365784

In [1]:
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(distiller.student)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model) 


NameError: name 'tfmot' is not defined

## Apply QAT and PCQAT and check effect on model clusters and sparsity

Next, apply both QAT and PCQAT on the sparse clustered model and observe that PCQAT preserves weight sparsity and clusters in your model. Note that the stripped model is passed to the QAT and PCQAT API.

In [270]:
# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PCQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
pcqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

pcqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

Train qat model:
Train pcqat model:










<tf_keras.src.callbacks.History at 0x16d73e250>

In [271]:
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)

QAT Model clusters:
quant_conv2d_28/conv2d_28/kernel:0: 66 clusters 
quant_dense_19/dense_19/kernel:0: 11763 clusters 

QAT Model sparsity:
conv2d_28/kernel:0: 37.96% sparsity  (41/108)
dense_19/kernel:0: 31.06% sparsity  (6298/20280)

PCQAT Model clusters:
quant_conv2d_28/conv2d_28/kernel:0: 9 clusters 
quant_dense_19/dense_19/kernel:0: 8 clusters 

PCQAT Model sparsity:
conv2d_28/kernel:0: 70.37% sparsity  (76/108)
dense_19/kernel:0: 89.71% sparsity  (18193/20280)


## See compression benefits of PCQAT model

Define helper function to get zipped model file.

In [272]:
def get_gzipped_model_size(file):
  # It returns the size of the gzipped model in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Observe that applying sparsity, clustering and PCQAT to a model yields significant compression benefits.

In [273]:
# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PCQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model.
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')

INFO:tensorflow:Assets written to: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp41ty6m4p/assets


INFO:tensorflow:Assets written to: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp41ty6m4p/assets
W0000 00:00:1715835084.550317  341577 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1715835084.550831  341577 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
2024-05-16 10:21:24.551621: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp41ty6m4p
2024-05-16 10:21:24.552728: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-05-16 10:21:24.552733: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp41ty6m4p
2024-05-16 10:21:24.562095: I tensorflow/cc/saved_model/loader.cc:234] Restoring SavedModel bundle.
2024-05-16 10:21:24.593714: I tensorflow/cc/saved_model/loader.cc:218] Running initialization op on SavedModel bundle at path: /var/folders/px/z8lb6znd

INFO:tensorflow:Assets written to: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp8ylysrw7/assets


INFO:tensorflow:Assets written to: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp8ylysrw7/assets


QAT model size:  17.169  KB
PCQAT model size:  4.578  KB


W0000 00:00:1715835085.429330  341577 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1715835085.429339  341577 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
2024-05-16 10:21:25.429439: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp8ylysrw7
2024-05-16 10:21:25.430709: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-05-16 10:21:25.430715: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp8ylysrw7
2024-05-16 10:21:25.450602: I tensorflow/cc/saved_model/loader.cc:234] Restoring SavedModel bundle.
2024-05-16 10:21:25.482693: I tensorflow/cc/saved_model/loader.cc:218] Running initialization op on SavedModel bundle at path: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp8ylysrw7
2024-05-16 10:21:25.492152: I tensorflow/cc/saved_model/loader.cc:

## See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

In [274]:
def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the model, which has been pruned, clustered and quantized, and then see that the accuracy from TensorFlow persists in the TFLite backend.

In [275]:
interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()

pcqat_test_accuracy = eval_model(interpreter)

print('Pruned + clustered + Distilled quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned + clustered + Distilled quantized TFLite test_accuracy: 0.9594
Baseline TF test accuracy: 0.9814000129699707


## Conclusion