## Quantization and Pruning

This notebook covers two mobile optimization techniques: quantization and pruning. Both reduce model size and latency, which makes them well suited to edge and IoT devices.

In [1]:
import tensorflow as tf
import numpy as np
import os
import tempfile
import zipfile

In [8]:
# GLOBAL VARIABLES
# Filenames for every model artifact written by this notebook.

# Keras / HDF5 artifacts
FILE_WEIGHTS = 'baseline.weights.h5'
FILE_NON_QUANTIZED_H5 = 'non_quantized.h5'
FILE_PRUNED_MODEL_H5 = 'pruned_model.h5'

# TF Lite artifacts
FILE_NON_QUANTIZED_TFLITE = 'non_quantized.tflite'
FILE_PT_QUANTIZED = 'post_training_quantized.tflite'
FILE_QAT_QUANTIZED = 'quant_aware_quantized.tflite'
FILE_PRUNED_QUANTIZED_TFLITE = 'pruned_quantized.tflite'
FILE_PRUNED_NON_QUANTIZED_TFLITE = 'pruned_non_quantized.tflite'

In [3]:
# Result tables keyed by model description: file sizes and test accuracies.
MODEL_SIZE, ACCURACY = {}, {}

In [15]:
def print_metrics(metric_dict, metric_name):
  '''Print one "<metric_name> for <key>: <value>" line per dictionary entry.

  Args:
    metric_dict (dict): mapping of model description -> measured value
    metric_name (str): label placed at the start of every printed line
  '''
  lines = [f'{metric_name} for {key}: {val}' for key, val in metric_dict.items()]
  for line in lines:
    print(line)


def model_builder():
  '''Build (but do not compile) a small CNN for 28x28 MNIST digit images.

  Layers: reshape to a single-channel image, one 3x3 Conv2D with 12 ReLU
  filters, 2x2 max-pooling, flatten, and a 10-unit softmax Dense layer.

  Returns:
    An uncompiled tf.keras.Sequential model.
  '''
  layers = tf.keras.layers
  return tf.keras.Sequential([
      # Accept 28x28 inputs and add the channel axis that Conv2D expects.
      layers.InputLayer(shape=(28, 28)),
      layers.Reshape(target_shape=(28, 28, 1)),
      # Feature extraction.
      layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
      layers.MaxPooling2D(pool_size=(2, 2)),
      # Classification head: one softmax probability per class.
      layers.Flatten(),
      layers.Dense(units=10, activation='softmax'),
  ])


def evaluate_tflite_model(filename, x_test, y_test):
  '''
  Measure the top-1 accuracy of a TF Lite classifier on a test set.

  Images are run through the interpreter one at a time (batch of 1); the
  predicted class is the argmax of the model's output vector.

  Args:
    filename: path to the TF Lite model
    x_test (numpy array): test images; each one gets a leading batch axis
      and is cast to float32 before inference
    y_test (numpy array): test labels, aligned element-wise with x_test

  Returns:
    accuracy: fraction of images whose predicted class equals the label

  Raises:
    ValueError: if x_test and y_test do not have the same number of samples
  '''
  if len(x_test) != len(y_test):
    raise ValueError(
        f'x_test has {len(x_test)} samples but y_test has {len(y_test)} labels')

  # Initialize the TF Lite interpreter and allocate tensors
  interpreter = tf.lite.Interpreter(model_path=filename)
  interpreter.allocate_tensors()

  # Only the first input/output tensor is used
  # (assumes a single-input, single-output model).
  input_index = interpreter.get_input_details()[0]['index']
  output_index = interpreter.get_output_details()[0]['index']

  prediction_digits = []
  for test_image in x_test:
    # Preprocessing: add batch dimension and convert to float32 to match the model's input format.
    batch = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, batch)

    interpreter.invoke()

    # get_tensor returns a copy, so no reference into the interpreter's
    # internal buffers is held; [0] drops the batch dimension.
    output = interpreter.get_tensor(output_index)
    prediction_digits.append(np.argmax(output[0]))

  prediction_digits = np.array(prediction_digits)
  return (prediction_digits == y_test).mean()

def get_gzipped_model_size(file):
  '''Return the size in bytes of a compressed archive containing `file`.

  The archive is a ZIP file with DEFLATE compression (not gzip, despite the
  name). It is written to a temporary file that is deleted before returning,
  and the descriptor opened by mkstemp is closed so repeated calls do not leak
  file handles or temp files.

  Args:
    file (str): path of the file to compress

  Returns:
    int: size of the compressed archive in bytes
  '''
  fd, zipped_file = tempfile.mkstemp('.zip')
  # ZipFile reopens the path itself; close mkstemp's descriptor right away.
  os.close(fd)
  try:
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(file)
    return os.path.getsize(zipped_file)
  finally:
    os.remove(zipped_file)

In [5]:
# Load the MNIST dataset through the Keras datasets helper (downloaded on first use).
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values into the [0, 1] range.
x_train = x_train / 255.0
x_test = x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [11]:
# Seed the Python, NumPy and TensorFlow RNGs so weight initialization (and
# hence the results below) is repeatable across Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Create the baseline model
baseline_model = model_builder()

# Save the initial (untrained) weights for later use
baseline_model.save_weights(FILE_WEIGHTS)

# Print the model summary
baseline_model.summary()

In [12]:
# Configure optimizer, loss and metric, then train for a single epoch.
# sparse_categorical_crossentropy takes integer class labels directly.
baseline_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# shuffle=False keeps the training samples in their stored order.
baseline_model.fit(x_train, y_train, epochs=1, shuffle=False)

[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 10ms/step - accuracy: 0.8656 - loss: 0.5122


<keras.src.callbacks.history.History at 0x7a9459c84bb0>

In [14]:
# Get the baseline accuracy. return_dict=True selects the metric by name and
# avoids rebinding `_`, which would stop IPython updating its output history.
ACCURACY['baseline Keras model'] = baseline_model.evaluate(x_test, y_test, return_dict=True)['accuracy']

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.9520 - loss: 0.1616


In [16]:
# Write the trained Keras model (without optimizer state) to the .h5 path
baseline_model.save(FILE_NON_QUANTIZED_H5, include_optimizer=False)

# Record its on-disk size in bytes
MODEL_SIZE.update({'baseline h5': os.path.getsize(FILE_NON_QUANTIZED_H5)})

# Show everything recorded so far
print_metrics(ACCURACY, 'Accuracy')
print_metrics(MODEL_SIZE, 'Model size')



Accuracy for baseline Keras model: 0.9595000147819519
Model size for baseline h5: 100992




Next, we will convert the model to TensorFlow Lite (TF Lite) format. TF Lite is designed to make TensorFlow models more efficient and lightweight when running on mobile, embedded and IoT devices.

In [17]:
def convert_tflite(model, filename, quantize=False):
  '''
  Convert a Keras model to TF Lite and write the result to disk.

  Args:
    model (Keras model): model to convert
    filename (str): destination path; opened in 'wb' mode, so an existing
      file is overwritten
    quantize (bool): if True, enable the converter's default optimizations
      (post-training quantization)

  Returns:
    None
  '''
  tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)

  if quantize:
    # With no representative dataset supplied, Optimize.DEFAULT performs
    # dynamic-range (weight) quantization.
    tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]

  serialized_model = tflite_converter.convert()

  with open(filename, 'wb') as out_file:
    out_file.write(serialized_model)

In [18]:
# Convert the float baseline model to TF Lite without quantization
convert_tflite(model=baseline_model, filename=FILE_NON_QUANTIZED_TFLITE, quantize=False)

Saved artifact at '/tmp/tmp0b5qa70v'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name='keras_tensor_18')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  134777580069536: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134777580073056: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134777580072704: TensorSpec(shape=(), dtype=tf.resource, name=None)
  134777580074816: TensorSpec(shape=(), dtype=tf.resource, name=None)


In [20]:
# Record the size in bytes of the non-quantized TF Lite file
MODEL_SIZE.update({'non quantized tflite': os.path.getsize(FILE_NON_QUANTIZED_TFLITE)})

print_metrics(MODEL_SIZE, 'model size in bytes')

model size in bytes for baseline h5: 100992
model size in bytes for non quantized tflite: 84912


There is already a slight decrease in model size when converting to the `.tflite` format. The accuracy should also be almost identical between the two.

In [24]:
# Test-set accuracy of the float (non-quantized) TF Lite model
ACCURACY.update({
    'non quantized tflite': evaluate_tflite_model(
        filename=FILE_NON_QUANTIZED_TFLITE, x_test=x_test, y_test=y_test),
})

print_metrics(ACCURACY, 'test accuracy')

test accuracy for baseline Keras model: 0.9595000147819519
test accuracy for non quantized tflite: 0.9595


In [None]:
# Convert the baseline model to TF Lite with post-training quantization enabled
convert_tflite(model=baseline_model, filename=FILE_PT_QUANTIZED, quantize=True)

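Now that we have the baseline metrics, we can observe the effects of quantization. Quantization converts floating-point representations into lower-precision integers to reduce model size and speed up computation. With `tf.lite.Optimize.DEFAULT` and no representative dataset, the TF Lite converter applies dynamic-range quantization, which stores the weights as 8-bit integers.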

In [23]:
# Record the size in bytes of the post-training quantized TF Lite file
MODEL_SIZE.update({'post training quantized tflite': os.path.getsize(FILE_PT_QUANTIZED)})

print_metrics(MODEL_SIZE, 'model size in bytes')

model size in bytes for baseline h5: 100992
model size in bytes for non quantized tflite: 84912
model size in bytes for post training quantized tflite: 24264


We can see that the quantized version is roughly 3.5X smaller, close to the theoretical 4X. This comes from storing the 32-bit floating-point weights as 8-bit integers. As the next cell shows, accuracy also stays almost the same in this case. Usually you can expect a small drop, but in some cases accuracy can even increase.

In [25]:
# Test-set accuracy of the post-training quantized TF Lite model
ACCURACY['post training quantized tflite'] = evaluate_tflite_model(FILE_PT_QUANTIZED, x_test, y_test)

# The helper is print_metrics; print_metric is undefined and raised NameError.
print_metrics(ACCURACY, 'test accuracy')

test accuracy for baseline Keras model: 0.9595000147819519
test accuracy for non quantized tflite: 0.9595
test accuracy for post training quantized tflite: 0.9596


In [None]:
!pip install tensorflow_model_optimization

Another technique for reducing model size is pruning. Pruning zeroes out insignificant (i.e. low-magnitude) weights. The idea is that these weights contribute little to the predictions, so we can remove them and get nearly the same results. Sparse weights also let the model compress much more efficiently.

The TensorFlow Model Optimization Toolkit again has a convenience method for this. The `prune_low_magnitude()` method wraps the layers of a Keras model so it can be pruned during training. We will pass in the baseline model that we already trained earlier. The model summary shows an increased parameter count because of the wrapper layers added by the pruning method.

In [None]:
# Get the pruning method
# NOTE(review): tensorflow_model_optimization targets Keras 2 (tf.keras / tf_keras).
# The model above is built with the Keras 3 API (InputLayer(shape=...)); if
# tf.keras resolves to Keras 3 here, prune_low_magnitude may reject the model.
# Confirm the installed TF / Keras / tfmot versions.
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
# end_step = (training batches per epoch) * epochs, where only the
# (1 - validation_split) fraction of x_train is used for training.
# NOTE: this only lines up with training if fit() is called with this same
# batch_size; otherwise the sparsity schedule ends at a different point.
batch_size=128
epochs=2
validation_split=0.1

num_images = x_train.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define pruning schedule: target sparsity goes from initial_sparsity (50%)
# to final_sparsity (80%) between begin_step and end_step.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step = end_step)
}

# Pass in the trained baseline model
model_for_pruning = prune_low_magnitude(baseline_model, **pruning_params)

# 'prune_low_magnitude' requires a recompile
model_for_pruning.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model_for_pruning.summary()

In [None]:
# Preview model weights before pruning-aware training. In the wrapped model,
# index 1 holds the Conv2D kernel (see the variable name in the output).
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[-0.65388936, -0.37027386, -0.5895651 ,  0.25797945,
           0.02541543, -0.6621699 ,  0.15978512, -0.03010637,
           0.26917607,  0.52216   ,  0.34164882, -0.14870733]],

        [[-0.47334045,  0.07472505, -0.5171312 ,  0.41284287,
           0.23645961, -0.08986535,  0.02500086,  0.2895436 ,
           0.2700939 ,  0.6329086 , -0.09309151, -0.2279    ]],

        [[-0.4112018 ,  0.26842037, -0.5538995 ,  0.6393494 ,
           0.1891551 ,  0.5963347 ,  0.11188681, -0.00722711,
           0.23119628,  0.41641408, -0.34437463, -0.00086102]]],


       [[[-0.16584733, -0.26762313, -0.10571435, -0.5680298 ,
           0.21757409, -0.6163616 , -0.1669713 ,  0.186569  ,
           0.01832947,  0.085479  ,  0.15577339,  0.28355384]],

        [[-0.03654844,  0.09806405,  0.16673827, -0.14484172,
           0.04716373,  0.12461241,  0.13102461,  0.19672811,
           0.02712863, -0.04133146,  0.111033

In [None]:
# Callback to update pruning wrappers at each step
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

# Train and prune the model. batch_size must be the same value used to compute
# end_step; without it fit() uses its default batch size, the number of steps
# per epoch changes, and the sparsity schedule no longer ends with training.
model_for_pruning.fit(x_train, y_train,
                      batch_size=batch_size,
                      epochs=epochs,
                      validation_split=validation_split,
                      callbacks=callbacks)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x7834d56d0550>

In [None]:
# Preview the same Conv2D kernel after pruning: many low-magnitude entries
# have now been zeroed out (compare with the preview before training).
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[-1.2199291 ,  0.        , -1.2455136 ,  0.        ,
          -0.        , -1.0761017 ,  0.        ,  0.        ,
           0.        ,  0.8754954 ,  0.7317014 ,  0.        ]],

        [[-1.0470268 ,  0.        , -0.61345565,  0.67668736,
          -0.        , -0.        ,  0.        ,  0.        ,
           0.        ,  0.99224913, -0.        ,  0.        ]],

        [[-0.        ,  0.        , -0.6936322 ,  0.9042901 ,
          -0.        ,  0.9175063 ,  0.        ,  0.        ,
           0.        ,  0.        , -0.        ,  0.        ]]],


       [[[ 0.        ,  0.        ,  0.        , -1.0346237 ,
          -0.        , -0.8650951 ,  0.        ,  0.        ,
           0.        ,  0.        , -0.        ,  0.7810932 ]],

        [[ 0.        ,  0.        ,  0.        ,  0.        ,
          -0.        , -0.        ,  0.        ,  0.        ,
           0.        ,  0.        , -0.      

In [None]:
# Remove pruning wrappers. The stripped model is what gets saved and converted
# below. Note that it is never compiled in this notebook.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 12)        0         
 D)                                                              
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20410 (79.73 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 0 (0.00 Byte)
____________________

In [None]:
# Preview the Conv2D kernel of the stripped model. It sits at index 0 here
# (index 1 in the wrapped model) because the pruning wrappers were removed.
model_for_export.weights[0]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[-1.2199291 ,  0.        , -1.2455136 ,  0.        ,
          -0.        , -1.0761017 ,  0.        ,  0.        ,
           0.        ,  0.8754954 ,  0.7317014 ,  0.        ]],

        [[-1.0470268 ,  0.        , -0.61345565,  0.67668736,
          -0.        , -0.        ,  0.        ,  0.        ,
           0.        ,  0.99224913, -0.        ,  0.        ]],

        [[-0.        ,  0.        , -0.6936322 ,  0.9042901 ,
          -0.        ,  0.9175063 ,  0.        ,  0.        ,
           0.        ,  0.        , -0.        ,  0.        ]]],


       [[[ 0.        ,  0.        ,  0.        , -1.0346237 ,
          -0.        , -0.8650951 ,  0.        ,  0.        ,
           0.        ,  0.        , -0.        ,  0.7810932 ]],

        [[ 0.        ,  0.        ,  0.        ,  0.        ,
          -0.        , -0.        ,  0.        ,  0.        ,
           0.        ,  0.        , -0.      

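The pruned model has the same file size as the baseline model when saved as H5. This is expected: the pruned weights are still stored as dense tensors, so each zero takes as much space as any other value. The improvement only shows up once the model is compressed, because many zeros compress very well.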

In [None]:
# Save Keras model
model_for_export.save(FILE_PRUNED_MODEL_H5, include_optimizer=False)

# Get uncompressed size of baseline and pruned models.
# MODEL_SIZE is reset so the table below lists only these two H5 files.
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned h5'] = os.path.getsize(FILE_PRUNED_MODEL_H5)

# The helper is print_metrics; print_metric is undefined and raised NameError.
print_metrics(MODEL_SIZE, 'model size in bytes')

  saving_api.save_model(


model size in bytes for baseline h5: 98968
model size in bytes for pruned h5: 98968


The compressed pruned model is about 3 times smaller than the compressed baseline.

In [None]:
# Get compressed (ZIP/DEFLATE) size of baseline and pruned models.
# MODEL_SIZE is reset so only compressed sizes are compared from here on.
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = get_gzipped_model_size(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned non quantized h5'] = get_gzipped_model_size(FILE_PRUNED_MODEL_H5)

# The helper is print_metrics; print_metric is undefined and raised NameError.
print_metrics(MODEL_SIZE, "gzipped model size in bytes")

gzipped model size in bytes for baseline h5: 78070
gzipped model size in bytes for pruned non quantized h5: 25993


We can make the model even more lightweight by quantizing the pruned model in `.tflite` format. This gives around a 10X reduction in compressed model size compared to the baseline.

In [None]:
# Convert and quantize the pruned model. convert_tflite writes the file and
# returns None, so its result is not assigned to anything.
convert_tflite(model_for_export, FILE_PRUNED_QUANTIZED_TFLITE, quantize=True)

# Compress and get the model size
MODEL_SIZE['pruned quantized tflite'] = get_gzipped_model_size(FILE_PRUNED_QUANTIZED_TFLITE)
# The helper is print_metrics; print_metric is undefined and raised NameError.
print_metrics(MODEL_SIZE, "gzipped model size in bytes")

gzipped model size in bytes for baseline h5: 78070
gzipped model size in bytes for pruned non quantized h5: 25993
gzipped model size in bytes for pruned quantized tflite: 8217


Finally, compare accuracy: quantizing the pruned model leaves test accuracy essentially unchanged.

In [None]:
# Get accuracy of pruned Keras and TF Lite models.
# ACCURACY is reset so only the pruned variants are compared.
ACCURACY = {}

# Evaluate the pruning-wrapped model: it was compiled above, whereas
# model_for_export (from strip_pruning) is never compiled in this notebook.
# return_dict=True selects the metric by name and avoids rebinding IPython's `_`.
ACCURACY['pruned model h5'] = model_for_pruning.evaluate(x_test, y_test, return_dict=True)['accuracy']
ACCURACY['pruned and quantized tflite'] = evaluate_tflite_model(FILE_PRUNED_QUANTIZED_TFLITE, x_test, y_test)

# The helper is print_metrics; print_metric is undefined and raised NameError.
print_metrics(ACCURACY, 'accuracy')

accuracy for pruned model h5: 0.9704999923706055
accuracy for pruned and quantized tflite: 0.9703
