In [1]:
import tensorflow as tf
import numpy as np
import os
import tempfile
import zipfile

#### Workflow
- post-training quantization
- quantization aware training
- weight pruning

In [2]:
# String constants for model filenames
FILE_WEIGHTS = 'baseline_weights.h5'
FILE_NON_QUANTIZED_H5 = 'non_quantized.h5'
FILE_NON_QUANTIZED_TFLITE = 'non_quantized.tflite'
FILE_PT_QUANTIZED = 'post_training_quantized.tflite'
FILE_QAT_QUANTIZED = 'quant_aware_quantized.tflite'
FILE_PRUNED_MODEL_H5 = 'pruned_model.h5'
FILE_PRUNED_QUANTIZED_TFLITE = 'pruned_quantized.tflite'
FILE_PRUNED_NON_QUANTIZED_TFLITE = 'pruned_non_quantized.tflite'

# Dictionaries to hold measurements
MODEL_SIZE = {}
ACCURACY = {}

#### Load Dataset

In [3]:
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

#### Utilities

In [4]:
def print_metric(metric_dict, metric_name):
  '''Prints key and values stored in a dictionary'''
  for metric, value in metric_dict.items():
    print(f'{metric_name} for {metric}: {value}')


def get_gzipped_model_size(file):
  '''Returns size of gzipped model, in bytes.'''
  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

#### Build Model

In [None]:
def model_builder():
  '''Returns a shallow CNN for training on the MNIST dataset'''

  keras = tf.keras

  # Define the model architecture.
  model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
  ])

  return model

#### Evaluate Model

In [32]:
def evaluate_tflite_model(filename, x_test, y_test):
  '''
  Measures the accuracy of a given TF Lite model and test set

  Args:
    filename (string) - filename of the model to load
    x_test (numpy array) - test images
    y_test (numpy array) - test labels

  Returns
    float showing the accuracy against the test set
  '''

  # Initialize the TF Lite Interpreter and allocate tensors
  interpreter = tf.lite.Interpreter(model_path=filename)
  interpreter.allocate_tensors()

  # Get input and output index
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Initialize empty predictions list
  prediction_digits = []

  # Run predictions on every image in the "test" dataset.
  for i, test_image in enumerate(x_test):
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == y_test).mean()

  return accuracy

####  Base Model Training
- Output: baseline_weights.h5
- Output: non_quantized.h5
- Output: non_quantized.tflite


In [5]:
# Create the baseline model
baseline_model = model_builder()

# Save the initial weights for use later
baseline_model.save_weights(FILE_WEIGHTS)

# Print the model summary
baseline_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 12)       0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
____________________________________________________

In [6]:
# Setup the model for training
baseline_model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
baseline_model.fit(train_images, train_labels, epochs=1, shuffle=False)

# Get the baseline accuracy
_, ACCURACY['baseline Keras model'] = baseline_model.evaluate(test_images, test_labels)



In [7]:
# Save the Keras model
baseline_model.save(FILE_NON_QUANTIZED_H5, include_optimizer=False)

# Save and get the model size
MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)

# Print records so far
print_metric(ACCURACY, "test accuracy")
print_metric(MODEL_SIZE, "model size in bytes")

test accuracy for baseline Keras model: 0.9602000117301941
model size in bytes for baseline h5: 98968


In [8]:
def convert_tflite(model, filename, quantize=False):
    '''
    Converts the model to TF Lite format and writes to a file

    Args:
        model (Keras model) - model to convert to TF Lite
        filename (string) - string to use when saving the file
        quantize (bool) - flag to indicate quantization

    Returns:
        None
    '''

    # Initialize the converter
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    # Set for quantization if flag is set to True
    if quantize:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Convert the model
    tflite_model = converter.convert()

    # Save the model.
    with open(filename, 'wb') as f:
        f.write(tflite_model)

In [9]:
convert_tflite(baseline_model, FILE_NON_QUANTIZED_TFLITE)



INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmphqyiagsh\assets


INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmphqyiagsh\assets


In [10]:
MODEL_SIZE['non quantized tflite'] = os.path.getsize(FILE_NON_QUANTIZED_TFLITE)

print_metric(MODEL_SIZE, 'model size in bytes')

model size in bytes for baseline h5: 98968
model size in bytes for non quantized tflite: 85012


In [11]:
ACCURACY['non quantized tflite'] = evaluate_tflite_model(FILE_NON_QUANTIZED_TFLITE, test_images, test_labels)
print_metric(ACCURACY, 'test accuracy')

test accuracy for baseline Keras model: 0.9602000117301941
test accuracy for non quantized tflite: 0.9602


#### Post Training Quantization

In [12]:
convert_tflite(baseline_model, FILE_PT_QUANTIZED, quantize=True)

# Get the model size
MODEL_SIZE['post training quantized tflite'] = os.path.getsize(FILE_PT_QUANTIZED)

print_metric(MODEL_SIZE, 'model size')

ACCURACY['post training quantized tflite'] = evaluate_tflite_model(FILE_PT_QUANTIZED, test_images, test_labels)

print_metric(ACCURACY, 'test accuracy')



INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmp15a6bwp5\assets


INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmp15a6bwp5\assets


model size for baseline h5: 98968
model size for non quantized tflite: 85012
model size for post training quantized tflite: 24256
test accuracy for baseline Keras model: 0.9602000117301941
test accuracy for non quantized tflite: 0.9602
test accuracy for post training quantized tflite: 0.9604


#### Quantization Aware Training
- Better for model accuracy
- Lossy performance from lower precision can be solved
- Simulates low precision behaviour in forward pass, while backward pass remains the same
- Induces some quantization error while the optimizer tries to reduce it by adjusting params


- Output model is quantization model but not quantized(weights are float32 instead of int8)
- A slight difference in the model summary compared to the baseline model summary
- The total params count increased as expected because of the nodes added by the `quantize_model()` method
- The method inserts fake quant nodes in the model during training and model will learn to adapt with the loss of precision to get more accurate predictions


In [14]:
import tensorflow_model_optimization as tfmot

# method to quantize a Keras model
quantize_model = tfmot.quantization.keras.quantize_model

# Define the model architecture.
model_to_quantize = model_builder()

# Reinitialize weights with saved file
model_to_quantize.load_weights(FILE_WEIGHTS)

# Quantize the model
q_aware_model = quantize_model(model_to_quantize)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

q_aware_model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLay  (None, 28, 28)           3         
 er)                                                             
                                                                 
 quant_reshape_1 (QuantizeWr  (None, 28, 28, 1)        1         
 apperV2)                                                        
                                                                 
 quant_conv2d_1 (QuantizeWra  (None, 26, 26, 12)       147       
 pperV2)                                                         
                                                                 
 quant_max_pooling2d_1 (Quan  (None, 13, 13, 12)       1         
 tizeWrapperV2)                                                  
                                                                 
 quant_flatten_1 (QuantizeWr  (None, 2028)            

In [15]:
q_aware_model.fit(train_images, train_labels, epochs=1, shuffle=False)



<keras.callbacks.History at 0x21dfd35d190>

In [16]:
# Reinitialize the dictionary
ACCURACY = {}

# Get the accuracy of the quantization aware trained model (not yet quantized)
_, ACCURACY['quantization aware non-quantized'] = q_aware_model.evaluate(test_images, test_labels, verbose=0)
print_metric(ACCURACY, 'test accuracy')

# Convert and quantize the model.
convert_tflite(q_aware_model, FILE_QAT_QUANTIZED, quantize=True)

# Get the accuracy of the quantized model
ACCURACY['quantization aware quantized'] = evaluate_tflite_model(FILE_QAT_QUANTIZED, test_images, test_labels)
print_metric(ACCURACY, 'test accuracy')

test accuracy for quantization aware non-quantized: 0.9591000080108643




INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmp7jpwuqtu\assets


INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmp7jpwuqtu\assets


test accuracy for quantization aware non-quantized: 0.9591000080108643
test accuracy for quantization aware quantized: 0.9591


#### Pruning
- This process involves zeroing out insignificant (low magnitude) weights
- The intuition is these weights do not contribute as much to making predictions so you can remove them and get the same result
- Making the weights sparse helps in compressing the model more efficiently

In [18]:
# Get the pruning method
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define pruning schedule.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

# Pass in the trained baseline model
model_for_pruning = prune_low_magnitude(baseline_model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model_for_pruning.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshape  (None, 28, 28, 1)        1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_conv2d   (None, 26, 26, 12)       230       
 (PruneLowMagnitude)                                             
                                                                 
 prune_low_magnitude_max_poo  (None, 13, 13, 12)       1         
 ling2d (PruneLowMagnitude)                                      
                                                                 
 prune_low_magnitude_flatten  (None, 2028)             1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_dense (  (None, 10)               4

In [19]:
# Preview model weights
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[ 0.32331395,  0.37700206,  0.07158457,  0.15839554,
          -0.5928061 ,  0.16970588,  0.27302247,  0.05167996,
           0.1029333 , -0.29762572,  0.40509358,  0.266184  ]],

        [[-0.00203653,  0.22536668,  0.16937006,  0.38190857,
          -0.10673852,  0.1033941 ,  0.26569822,  0.21276605,
           0.03750268, -0.6709579 ,  0.6776479 ,  0.30801836]],

        [[-0.00610922,  0.02296782, -0.25250146,  0.1620334 ,
           0.42038572,  0.2590574 ,  0.02034328, -0.15450105,
           0.09606051, -0.4990606 ,  0.7326869 , -0.09349032]]],


       [[[ 0.16339053,  0.44928578,  0.21899992, -0.01333159,
          -0.6395615 ,  0.01518945,  0.3147946 ,  0.10961092,
           0.13365772,  0.03120475, -0.32336253, -0.00616928]],

        [[-0.03096809, -0.16775352,  0.24548844,  0.20115243,
          -0.0015689 ,  0.12037516, -0.05883881,  0.2972516 ,
           0.28075826,  0.27383828, -0.175120

In [20]:
# Callback to update pruning wrappers at each step
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
]

# Train and prune the model
model_for_pruning.fit(train_images, train_labels,
                  epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x21da62bbac0>

In [21]:
# Preview model weights
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[ 0.        ,  0.7375559 , -0.        ,  0.        ,
          -0.88801235,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        , -0.        ]],

        [[ 0.        ,  0.        , -0.        ,  0.8130417 ,
          -0.        ,  0.        ,  0.        ,  0.        ,
           0.        , -1.2675618 ,  0.97723716,  0.7961744 ]],

        [[ 0.        ,  0.        , -0.        ,  0.        ,
           0.6954159 ,  0.        ,  0.        , -0.        ,
           0.        ,  0.        ,  1.0631871 , -0.        ]]],


       [[[ 0.        ,  0.8605922 , -0.        ,  0.        ,
          -1.0491787 ,  0.        ,  0.        , -0.        ,
           0.        ,  0.        ,  0.        , -0.        ]],

        [[ 0.        , -0.        , -0.        ,  0.        ,
          -0.        ,  0.        ,  0.        ,  0.8850673 ,
           0.        ,  0.        ,  0.      

In [22]:
# Remove pruning wrappers
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 12)       0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
____________________________________________________

In [23]:
# Preview model weights (index 1 earlier is now 0 because pruning wrappers were removed)
model_for_export.weights[0]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[ 0.        ,  0.7375559 , -0.        ,  0.        ,
          -0.88801235,  0.        ,  0.        ,  0.        ,
           0.        ,  0.        ,  0.        , -0.        ]],

        [[ 0.        ,  0.        , -0.        ,  0.8130417 ,
          -0.        ,  0.        ,  0.        ,  0.        ,
           0.        , -1.2675618 ,  0.97723716,  0.7961744 ]],

        [[ 0.        ,  0.        , -0.        ,  0.        ,
           0.6954159 ,  0.        ,  0.        , -0.        ,
           0.        ,  0.        ,  1.0631871 , -0.        ]]],


       [[[ 0.        ,  0.8605922 , -0.        ,  0.        ,
          -1.0491787 ,  0.        ,  0.        , -0.        ,
           0.        ,  0.        ,  0.        , -0.        ]],

        [[ 0.        , -0.        , -0.        ,  0.        ,
          -0.        ,  0.        ,  0.        ,  0.8850673 ,
           0.        ,  0.        ,  0.      

In [24]:
# Save Keras model
model_for_export.save(FILE_PRUNED_MODEL_H5, include_optimizer=False)

# Get uncompressed model size of baseline and pruned models
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned non quantized h5'] = os.path.getsize(FILE_PRUNED_MODEL_H5)

print_metric(MODEL_SIZE, 'model_size in bytes')





model_size in bytes for baseline h5: 98968
model_size in bytes for pruned non quantized h5: 98968


In [25]:
# Get compressed size of baseline and pruned models
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = get_gzipped_model_size(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned non quantized h5'] = get_gzipped_model_size(FILE_PRUNED_MODEL_H5)

print_metric(MODEL_SIZE, "gzipped model size in bytes")

gzipped model size in bytes for baseline h5: 77994
gzipped model size in bytes for pruned non quantized h5: 25955


In [26]:
# Convert and quantize the pruned model.
pruned_quantized_tflite = convert_tflite(model_for_export, FILE_PRUNED_QUANTIZED_TFLITE, quantize=True)

# Compress and get the model size
MODEL_SIZE['pruned quantized tflite'] = get_gzipped_model_size(FILE_PRUNED_QUANTIZED_TFLITE)
print_metric(MODEL_SIZE, "gzipped model size in bytes")



INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmpcj7yhq0v\assets


INFO:tensorflow:Assets written to: C:\Users\clare\AppData\Local\Temp\tmpcj7yhq0v\assets


gzipped model size in bytes for baseline h5: 77994
gzipped model size in bytes for pruned non quantized h5: 25955
gzipped model size in bytes for pruned quantized tflite: 8235


In [27]:
# Get accuracy of pruned Keras and TF Lite models
ACCURACY = {}

_, ACCURACY['pruned model h5'] = model_for_pruning.evaluate(test_images, test_labels)
ACCURACY['pruned and quantized tflite'] = evaluate_tflite_model(FILE_PRUNED_QUANTIZED_TFLITE, test_images, test_labels)

print_metric(ACCURACY, 'accuracy')

accuracy for pruned model h5: 0.9675999879837036
accuracy for pruned and quantized tflite: 0.9683
