   # The Ultimate Compression Pipeline ~ Ajay Maheshwari

In [100]:
! pip install -q tensorflow-model-optimization

In [101]:
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

In [102]:
def get_gzipped_model_size(model):
  with tempfile.NamedTemporaryFile(suffix=".h5") as temp_file:  
    model.save(temp_file.name)

    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(temp_file.name)
    
    # print(f"Zipped model is saved at: {zipped_file}")

    x = os.path.getsize(zipped_file)
    
    os.remove(zipped_file)
    # print(f"Temporary zip file removed: {zipped_file}")
        
    return x / 1000

def get_model_multiplications(model):
    # Calculate the number of multiplications needed for inference
    num_multiplications = calculate_multiplications(model)
    return num_multiplications

def calculate_multiplications(model):
    # Calculate the number of multiplications needed for inference
    num_multiplications = 0
    for layer in model.layers:
        if hasattr(layer, 'get_weights'):
            weights = layer.get_weights()
            if weights:
                num_params = sum(w.size for w in weights)
                num_nonzero_params = sum(1 for w in weights for val in w.flatten() if val != 0)
                num_multiplications += num_nonzero_params  # Assuming each non-zero weight multiplication as one operation
    return num_multiplications

def print_model_weights_sparsity(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )


def get_gzipped_model_size2(file):

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000


def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  prediction_digits = []
  for i, test_image in enumerate(test_images):
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    interpreter.invoke()

    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

model_acc = []
model_sz = []

## Creating a Base Model - 1.0

In [103]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

opt = keras.optimizers.Adam(learning_rate=1e-3)

# Train the digit classification model
model.compile(optimizer=opt,
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tf_keras.src.callbacks.History at 0x17eabad50>

### Evaluate the baseline model and save it for later usage

In [104]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

sz = get_gzipped_model_size(model)
print("Base model size: ",  sz , ' KB' )

model_acc.append(baseline_model_accuracy)
model_sz.append(sz)

print("Total Multiplications = ", get_model_multiplications(model))

Baseline test accuracy: 0.9825000166893005
Base model size:  234.936  KB
Total Multiplications =  20410


                         ---------- Checkpoint Point 1 ---------

In [66]:
from keras import layers
from keras import ops
import numpy as np

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)
    
    
    
teacher = keras.Sequential(
    [
        keras.layers.InputLayer(input_shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(filters=16, kernel_size=(3, 3), activation=tf.nn.relu),  # Increase filters
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(filters=24, kernel_size=(3, 3), activation=tf.nn.relu),  # Add another layer
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(units=128, activation=tf.nn.relu),  # Add a hidden layer
        keras.layers.Dense(10)
    ],
    name="teacher",
)


# teacher = keras.Sequential(
#     [
#           keras.layers.InputLayer(input_shape=(28, 28)),
#           keras.layers.Reshape(target_shape=(28, 28, 1)),
#           keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
#                                  activation=tf.nn.relu),
#           keras.layers.MaxPooling2D(pool_size=(2, 2)),
#           keras.layers.Flatten(),
#           keras.layers.Dense(10)
#     ],
#     name="teacher",
# )


teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

teacher.fit(train_images, train_labels, epochs=5)
teacher.evaluate(test_images, test_labels)




Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.032516371458768845, 0.989799976348877]

## Pruning and then fine-tuning the model - 2.0

In [105]:
import tensorflow_model_optimization as tfmot

callbacks = [
      tfmot.sparsity.keras.UpdatePruningStep()
]    
    
pruned_model = model


for i in range(1):
    
    pruning_params = {
          'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)     
    }
    
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    pruned_model = prune_low_magnitude(pruned_model, **pruning_params)

    # learning rate for fine-tuning
    opt = keras.optimizers.Adam(learning_rate=1e-5)

    pruned_model.compile(
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=opt,
      metrics=['accuracy'])
    
    # Fine-tune model
    pruned_model.fit(
      train_images,
      train_labels,
      epochs=5,
      validation_split=0.1,
      callbacks=callbacks)
    
    
    stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
    
    save_directory = '/Users/ajaymaheshwari/Desktop/Models/exp/'
    
    model_path = os.path.join(save_directory, f"modelPruned{i}.h5")
    model.save(model_path)
    
    print_model_weights_sparsity(stripped_pruned_model)
    
    
    



Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
conv2d_7/kernel:0: 50.00% sparsity  (54/108)
dense_7/kernel:0: 50.00% sparsity  (10140/20280)


In [106]:
pruned_model.compile(
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=opt,
      metrics=['accuracy'])

In [97]:
# from keras.models import load_model
# import os
# from keras.losses import SparseCategoricalCrossentropy

# # Custom wrapper to handle the deserialization of the loss function
# class CustomSparseCategoricalCrossentropy(SparseCategoricalCrossentropy):
#     def __init__(self, **kwargs):
#         if 'fn' in kwargs:
#             del kwargs['fn']
#         # Ensure the 'reduction' parameter is set to a valid value
#         if kwargs.get('reduction') == 'auto':
#             kwargs['reduction'] = 'sum_over_batch_size'
#         super().__init__(**kwargs)

# # Register the custom loss function
# custom_objects = {
#     'SparseCategoricalCrossentropy': CustomSparseCategoricalCrossentropy
# }

# # Define the directory and the model file name
# save_directory = '/Users/ajaymaheshwari/Desktop/Models/exp/'
# model_path = os.path.join(save_directory, "modelPruned0.h5")

# # Load the model with the custom objects
# loaded_model = load_model(model_path, custom_objects=custom_objects)

# # Use the model (for demonstration, we just print its summary)
# loaded_model.summary()


In [98]:


# distiller = Distiller(student=stripped_pruned_model, teacher=teacher)
# distiller.compile(
#         optimizer=keras.optimizers.Adam(),
#         metrics=[keras.metrics.SparseCategoricalAccuracy()],
#         student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#         distillation_loss_fn=keras.losses.KLDivergence(),
#         alpha=0.1,
#         temperature=10,
#     )

# # Distill teacher se student
# distiller.fit(train_images, train_labels, epochs=2)

# distiller.evaluate(test_images, test_labels)

# pruned_model = distiller.student


In [71]:
pruned_model.compile(
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=opt,
      metrics=['accuracy'])

In [107]:
_, pruned_model_accuracy = pruned_model.evaluate(
    test_images, test_labels, verbose=0)

print('Pruned Model test accuracy:', pruned_model_accuracy)

sz = get_gzipped_model_size(pruned_model)
print("Stripped model size: ",  sz , ' KB' )

model_acc.append(pruned_model_accuracy)
model_sz.append(sz)

print("Total Multiplications = ", get_model_multiplications(pruned_model))

Pruned Model test accuracy: 0.9763000011444092
Stripped model size:  211.454  KB
Total Multiplications =  10216


                         ---------- Checkpoint Point 2 ---------

## Knowledge Distillation - 3.0

In [73]:
weights = pruned_model.get_weights()
print(len(weights))

4


In [74]:
def svd_compress(weights, rank):
    compressed_weights = []
    for i in range(0, len(weights), 2): 
        W = weights[i]
        b = weights[i + 1]
        if len(W.shape) > 2:  
            W_reshaped = tf.reshape(W, [W.shape[0] * W.shape[1] * W.shape[2], W.shape[3]])
            s, U, V = tf.linalg.svd(W_reshaped, full_matrices=False)
            U_r, s_r, V_r = U[:, :rank], tf.linalg.diag(s[:rank]), tf.transpose(V[:, :rank])
            W_r = tf.matmul(tf.matmul(U_r, s_r), V_r)
            W_r = tf.reshape(W_r, W.shape)
            compressed_weights.append(W_r)
        else: 
            s, U, V = tf.linalg.svd(W, full_matrices=False)
            U_r, s_r, V_r = U[:, :rank], tf.linalg.diag(s[:rank]), tf.transpose(V[:, :rank])
            W_r = tf.matmul(tf.matmul(U_r, s_r), V_r)
            compressed_weights.append(W_r)
        compressed_weights.append(b) 
    return compressed_weights


In [75]:


# Get the weights of the pruned model
weights = pruned_model.get_weights()

# Compress the weights
rank = 100
compressed_weights = svd_compress(weights, rank)

In [76]:
# print((compressed_weights[0]))
# print(weights[0])

In [77]:
_, pruned_model_accuracy = pruned_model.evaluate(
    test_images, test_labels, verbose=0)

print('Pruned Model test accuracy:', pruned_model_accuracy)

sz = get_gzipped_model_size(pruned_model)
print("Stripped model size: ",  sz , ' KB' )

model_acc.append(pruned_model_accuracy)
model_sz.append(sz)


Pruned Model test accuracy: 0.9686999917030334
Stripped model size:  228.5  KB


In [78]:
# _, acc = distiller.student.evaluate(
#     test_images, test_labels, verbose=0)

# print('Distilled Model test accuracy:', acc)

# sz = get_gzipped_model_size(distiller.student)
# print("Distilled model size: ",  sz , ' KB' )

# model_acc.append(acc)
# model_sz.append(sz)

                         ---------- Checkpoint Point 3 ---------

## Weight Clustering - 4.0

In [79]:
def print_model_weight_clusters(model):
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

In [80]:
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(pruned_model, **clustering_params)

sparsity_clustered_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels,epochs=6, validation_split=0.1)

Train sparsity preserving clustering model:
Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<tf_keras.src.callbacks.History at 0x172ba7b50>

In [81]:
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)

Model sparsity:

kernel:0: 32.41% sparsity  (35/108)
kernel:0: 60.07% sparsity  (12182/20280)

Model clusters:

conv2d_3/kernel:0: 8 clusters 
dense_3/kernel:0: 8 clusters 


In [82]:
stripped_clustered_model.compile(optimizer=opt,
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [83]:
_, stripped_clustered_model_accuracy = stripped_clustered_model.evaluate(
    test_images, test_labels, verbose=0)

print('Clustered Model test accuracy:', stripped_clustered_model_accuracy)

sz = get_gzipped_model_size(stripped_clustered_model)
print("Clustered model size: ",  sz , ' KB' )

model_acc.append(stripped_clustered_model_accuracy)
model_sz.append(sz)

Clustered Model test accuracy: 0.9771000146865845
Clustered model size:  166.025  KB


                         ---------- Checkpoint Point 4 ---------

## Quantization - 5.0

In [84]:
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
quant_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

quant_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Training after quantization model:')
quant_model.fit(train_images, train_labels, batch_size=128, epochs=3, validation_split=0.1)

Training after quantization model:
Epoch 1/3








Epoch 2/3
Epoch 3/3


<tf_keras.src.callbacks.History at 0x172b19650>

In [85]:
print("Final Model clusters:")
print_model_weight_clusters(quant_model)
print("\nFinal Model sparsity:")
print_model_weights_sparsity(quant_model)

Final Model clusters:
quant_conv2d_3/conv2d_3/kernel:0: 8 clusters 
quant_dense_3/dense_3/kernel:0: 8 clusters 

Final Model sparsity:
conv2d_3/kernel:0: 34.26% sparsity  (37/108)
dense_3/kernel:0: 60.35% sparsity  (12239/20280)


In [86]:
converter = tf.lite.TFLiteConverter.from_keras_model(quant_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
final_tflite_model = converter.convert()
final_model_file = 'final_model.tflite'
# Save the model.
with open(final_model_file, 'wb') as f:
    f.write(final_tflite_model)


sz = get_gzipped_model_size2(final_model_file)
print("Final model size: ", sz, ' KB')

model_sz.append(sz)

INFO:tensorflow:Assets written to: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp2z054_n3/assets


INFO:tensorflow:Assets written to: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp2z054_n3/assets


Final model size:  7.173  KB


W0000 00:00:1716868683.403197   14434 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1716868683.403743   14434 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.
2024-05-28 09:28:03.406097: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp2z054_n3
2024-05-28 09:28:03.407396: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-05-28 09:28:03.407401: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/px/z8lb6znd6q95tq6vlznyb0s40000gn/T/tmp2z054_n3
2024-05-28 09:28:03.416944: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
2024-05-28 09:28:03.417761: I tensorflow/cc/saved_model/loader.cc:234] Restoring SavedModel bundle.
2024-05-28 09:28:03.452701: I tensorflow/cc/saved_model/loader.cc:218] Running initialization op on SavedModel bundle at 

In [87]:
interpreter = tf.lite.Interpreter(final_model_file)
interpreter.allocate_tensors()
 
final_test_accuracy = eval_model(interpreter)

print('Final test accuracy:', final_test_accuracy)

model_acc.append(final_test_accuracy)

I tensorflow/cc/saved_model/loader.cc:317] SavedModel load for tags { serve }; Status: success: OK. Took 55657 microseconds.
2024-05-28 09:28:03.495906: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


Final test accuracy: 0.9803


                         ---------- Checkpoint Point 5 ---------

In [88]:
for i in range(len(model_acc)):
    print(f"Accuracy = { round(model_acc[i]*100,2)} with size = {model_sz[i]} KB ")

Accuracy = 98.01 with size = 235.559 KB 
Accuracy = 96.87 with size = 228.5 KB 
Accuracy = 96.87 with size = 228.5 KB 
Accuracy = 97.71 with size = 166.025 KB 
Accuracy = 98.03 with size = 7.173 KB 


                         ------- Final Comparison Summary -------