<a href="https://colab.research.google.com/github/GloC99/diagrams/blob/5CCSACCA/ModelPruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this lab you will use the tensorflow-model-optimization library to perform model quantisation and pruning on a simple DNN model that classifies hand-written digits from the MNIST dataset.

Let's first install the necessary libraries:

In [None]:
!pip install tensorflow
!pip install tensorflow-model-optimization

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 242.5/242.5 kB 2.9 MB/s eta 0:00:00
Installing collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0


Now, let's download the MNIST dataset and perform some pre-processing for it:

In [None]:
from tensorflow_model_optimization.python.core.keras.compat import keras

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Our next step is to build the structure of our DNN model, compile it and start the training process. Note that we will train the model for 5 epochs.

In [None]:
# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=5,
  validation_split=0.1,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x7c727058ee30>

Once the model has been trained, let's evaluate its performance on the test data.

In [None]:
loss, accuracy = model.evaluate(test_images, test_labels)
print(accuracy)

0.979200005531311


Now that we have a high-performing model, we can apply quantisation to it. For this we use the quantize_model function from the tensorflow_model_optimization library. The following piece of code applies quantize_model to our trained digit-classification model, which makes it quantisation-aware, and re-compiles the model.

In [None]:
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Let's check how the quantisation has affected the accuracy of our model.

In [None]:
loss, q_aware_accuracy = q_aware_model.evaluate(test_images, test_labels)
print(q_aware_accuracy)

0.11349999904632568


As you can see, the accuracy has dropped significantly, to roughly chance level for ten classes. However, this is normal, as we have not yet fine-tuned the model after wrapping it for quantisation. Let's fine-tune the model by training it for 1 epoch on the first 1000 training images.

In [None]:
q_aware_model.fit(train_images[0:1000], train_labels[0:1000], epochs=1, batch_size=32)



<tf_keras.src.callbacks.History at 0x7c0a0ab8e5c0>

Let's now evaluate the accuracy of the quantised and fine-tuned model.

In [None]:
loss, q_aware_accuracy = q_aware_model.evaluate(test_images, test_labels)
print(q_aware_accuracy)

0.9805999994277954


We can see that the accuracy of the model has recovered and is very close to that of the original, non-quantised model. How cool!

Now, let's check the size of the models:

In [None]:
# Save the original (float) model
model.save('original_model.h5')

# Save the quantisation-aware model
q_aware_model.save('q_aware_model.h5')

import os
original_size = os.path.getsize('original_model.h5')
quantised_size = os.path.getsize('q_aware_model.h5')

print(f"Original model size: {original_size / 1024:.2f} KB")
print(f"Quantised model size: {quantised_size / 1024:.2f} KB")
print(f"Size reduction: {(1 - quantised_size / original_size) * 100:.2f}%")

Original model size: 265.97 KB
Quantised model size: 282.87 KB
Size reduction: -6.35%




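You can see that the file size does not decrease; here it even increases slightly. This is because the quantisation-aware model still stores its weights as 32-bit floats, together with some extra quantisation parameters. We have not yet performed the conversion step that actually produces a low-precision model.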
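To see why a quantisation-aware model is not smaller on disk, it can help to look at what "fake" quantisation does to a tensor. The snippet below is a simplified NumPy illustration, not the library's implementation: `fake_quantize` is a hypothetical helper that snaps values onto a uniform 8-bit grid and then maps them straight back to float32.

```python
import numpy as np

def fake_quantize(x, num_bits=8):
    """Round x onto a uniform num_bits grid, then map back to float32.

    Hypothetical helper for illustration only; the real library handles
    ranges and rounding in its own way.
    """
    x = np.asarray(x, dtype=np.float32)
    x_min, x_max = float(x.min()), float(x.max())
    levels = 2 ** num_bits - 1
    scale = (x_max - x_min) / levels if x_max > x_min else 1.0
    q = np.round((x - x_min) / scale)      # grid index in [0, levels]
    return (q * scale + x_min).astype(np.float32)

rng = np.random.default_rng(0)
weights = rng.normal(size=1000).astype(np.float32)
fq = fake_quantize(weights)

print(fq.dtype, fq.nbytes == weights.nbytes)   # dtype and storage are unchanged
print(len(np.unique(fq)) <= 256)               # but only up to 256 distinct values remain
print(float(np.abs(fq - weights).max()))       # worst-case rounding error
```

The values behave as if they were 8-bit, which is what lets the model adapt to quantisation during fine-tuning, but every value still occupies four bytes until the model is converted.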

Let's now use TFLiteConverter to convert our quantised model into TFLite format with default optimisation configurations:

In [None]:
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()



Let's now implement the code that will allow us to evaluate the TFLite model. We cannot use the standard evaluate method for TFLite, as TFLite operates differently from TensorFlow/Keras during inference. More specifically, a TFLite model is loaded and executed using a TFLite interpreter, which operates on a lightweight, deployment-friendly runtime designed for inference only. It does not have built-in functions for model evaluation, like model.evaluate in Keras.

In [None]:
import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy


Now, let's evaluate the TFLite model we have with the implemented evaluation method.

In [None]:
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_accuracy)




Quant TFLite test_accuracy: 0.9806
Quant TF test accuracy: 0.9805999994277954


Now, let's convert the original model to TFLite as well, save both models to disk and see how their sizes differ.

In [None]:
import tempfile

# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
  f.write(float_tflite_model)

print("Float model in MB:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in MB:", os.path.getsize(quant_file) / float(2**20))

Float model in MB: 0.08073043823242188
Quantized model in MB: 0.023681640625


We can see that the model size is reduced by a factor of around 3.4 (from about 0.081 MB to about 0.024 MB), without any significant loss in accuracy. This concludes the part of this lab which focuses on model quantisation.
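As a rough sanity check on these file sizes, we can count the model's parameters by hand and estimate how much space the weights alone need at 4 bytes (float32) versus 1 byte (int8) each. This is back-of-the-envelope arithmetic under simplifying assumptions: it ignores file-format overhead, and some tensors (biases, for example) may be stored at higher precision in the converted model.

```python
# Layer shapes taken from the model defined earlier in this lab.
conv_params = 3 * 3 * 1 * 12 + 12      # 3x3 kernels, 1 input channel, 12 filters, plus biases
conv_out = 28 - 3 + 1                  # convolution without padding: 26 x 26
pooled = conv_out // 2                 # 2x2 max pooling: 13 x 13
dense_params = pooled * pooled * 12 * 10 + 10

total_params = conv_params + dense_params
float32_mib = total_params * 4 / 2**20
int8_mib = total_params * 1 / 2**20

print("parameters:", total_params)
print(f"float32 weights: {float32_mib:.4f} MB")
print(f"int8 weights:    {int8_mib:.4f} MB")
```

The weights alone would shrink by exactly 4 times. The measured ratio is lower because each file also contains metadata and graph structure that do not shrink with the weights.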


Let's now use the tensorflow-model-optimization library to perform model pruning. We first define a pruning schedule. Here are the parameters we use:

**initial_sparsity**: The fraction of weights set to zero at the start of pruning.

**final_sparsity**: The fraction of weights to be zeroed out by the end of pruning.

**begin_step**: The training step at which pruning begins.

**end_step**: The training step at which pruning ends.

**frequency**: The interval (in steps) at which pruning is applied.
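To build some intuition for how these parameters interact, here is a small pure-Python sketch of a polynomial decay schedule. It is not the library's code: `polynomial_sparsity` is a hypothetical helper, and the exponent `power=3` is an assumption about the library's default. The other defaults mirror the values used in this lab.

```python
def polynomial_sparsity(step, initial_sparsity=0.0, final_sparsity=0.5,
                        begin_step=0, end_step=2000, power=3):
    """Target sparsity at a given training step (illustrative sketch)."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))   # clamp to the pruning window
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

for step in (0, 500, 1000, 2000, 5000):
    print(step, round(polynomial_sparsity(step), 4))
```

Under these assumptions, sparsity rises quickly at first and then flattens out as it approaches the final value, and it stays at the final value after end_step.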

In [None]:
# Define the pruning schedule
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,  # Start with no sparsity
    final_sparsity=0.5,    # Target 50% sparsity
    begin_step=0,
    end_step=2000,  # Adjust this value depending on the total number of steps
    frequency=100   # Apply pruning every 100 steps
)
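When choosing end_step, it helps to know how many training steps the fine-tuning run will actually take. The following sketch works this out for the settings used in this lab, assuming the Keras default batch size of 32 (fit is called without a batch_size here).

```python
import math

num_train_images = 60000      # size of the MNIST training set
validation_split = 0.1        # as passed to fit() in this lab
batch_size = 32               # assumed: Keras default when none is given
epochs = 5

train_samples = round(num_train_images * (1 - validation_split))
steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs

print("steps per epoch:", steps_per_epoch)
print("total steps:", total_steps)
```

With end_step=2000 the target sparsity is therefore reached early in the second epoch, and the remaining epochs fine-tune the model at the final sparsity level.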

We now wrap the model to include pruning logic and re-compile the model.

In [None]:
# Apply pruning to the model
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)

# Compile the pruned model
pruned_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
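Conceptually, magnitude-based pruning zeroes out the weights with the smallest absolute values. The NumPy sketch below illustrates the idea on a small random matrix. `magnitude_mask` is a hypothetical helper, not part of the library, and the real wrappers apply their masks gradually during training rather than in one shot.

```python
import numpy as np

def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that drops the `sparsity` fraction of smallest |w|."""
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))       # number of weights to zero
    if k == 0:
        return np.ones_like(weights)
    threshold = np.sort(flat)[k - 1]           # k-th smallest magnitude
    return (np.abs(weights) > threshold).astype(weights.dtype)

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 5)).astype(np.float32)
mask = magnitude_mask(w, sparsity=0.5)
pruned = w * mask

print(f"sparsity: {np.mean(pruned == 0):.2f}")
```

The pruned tensor keeps its original shape and dtype; only the values change.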

We create a pruning-specific callback to update the pruning step during training. We then fine-tune the pruned model, gradually increasing sparsity as per the schedule while monitoring accuracy.

In [None]:
# Create pruning callback
pruning_callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]

# Fine-tune the pruned model with the pruning callback
pruned_model.fit(
    train_images,
    train_labels,
    epochs=5,
    validation_split=0.1,
    callbacks=pruning_callbacks
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x7c0a36f208e0>

We then strip the pruning wrappers, re-compile the resulting model and evaluate its performance.

In [None]:
# Strip the pruning wrappers to obtain the final pruned model
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

final_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

# Now you can proceed with evaluation
pruned_accuracy = final_model.evaluate(test_images, test_labels, verbose=0)[1]

print(f"Pruned model accuracy: {pruned_accuracy * 100:.2f}%")


Pruned model accuracy: 98.13%


We can see that the model's accuracy remains almost the same. Now, let's see how the size of the model changes.

In [None]:
# Save the original model (if you haven't pruned it yet)
model.save('original_model.h5')

# Save the pruned model (after stripping the pruning wrappers)
final_model.save('pruned_model.h5')

import os
original_size = os.path.getsize('original_model.h5')
pruned_size = os.path.getsize('pruned_model.h5')

print(f"Original model size: {original_size / 1024:.2f} KB")
print(f"Pruned model size: {pruned_size / 1024:.2f} KB")
print(f"Size reduction: {(1 - pruned_size / original_size) * 100:.2f}%")

Original model size: 265.97 KB
Pruned model size: 98.03 KB
Size reduction: 63.14%


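We can see that the saved file shrinks by around 63% without any significant loss in accuracy. Keep in mind that pruning sets weights to zero but does not remove them from the stored tensors, so the sparsity itself only pays off once the file is compressed (for example with gzip). Part of the difference measured here probably comes from elsewhere, for instance optimizer state being saved with the trained original model but not with the freshly re-compiled pruned model.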
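A quick way to see how sparsity relates to storage is to compare raw and gzip-compressed sizes of a dense and a half-zeroed array. The sketch below uses synthetic random weights rather than the lab's model, so the exact numbers are only illustrative, and `gzipped_size` is a hypothetical helper.

```python
import gzip
import numpy as np

rng = np.random.default_rng(0)
dense = rng.normal(size=20000).astype(np.float32)

# Zero the half of the values with the smallest magnitude.
sparse = dense.copy()
sparse[np.abs(sparse) < np.median(np.abs(sparse))] = 0.0

def gzipped_size(array):
    """Size in bytes of the array's raw buffer after gzip compression."""
    return len(gzip.compress(array.tobytes()))

print("raw bytes:    ", dense.nbytes, sparse.nbytes)
print("gzipped bytes:", gzipped_size(dense), gzipped_size(sparse))
```

The raw sizes are identical, while the compressed size of the sparse array is noticeably smaller. The same idea can be applied to the saved model files by compressing them before measuring their size.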