# Introduction

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/EfficientDL/book/blob/main/codelabs/Chapter-10-Tensorflow/Sparse_Model_Training_and_Inference.ipynb)

This toy colab demonstrates pruning and sparse-inference acceleration in TFLite on an image-classification problem, using the CIFAR-10 dataset and a vanilla CNN. We will use the TensorFlow Model Optimization Toolkit's pruning library to create a sparse model. We will then accelerate this model using the [XNNPACK Delegate for TFLite](https://blog.tensorflow.org/2020/07/accelerating-tensorflow-lite-xnnpack-integration.html) on Android; you should be able to see similar gains on other ARM devices too.

Currently, the XNNPACK delegate supports only a subset of operators, and getting latency improvements places [further restrictions](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#sparse-inference) on the model graph.

**Credit**: The following colab is based on the [original guide](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_for_on_device_inference) authored by the TFMOT library authors with some changes that improve the model quality, simplify the flow a little bit, invoke the model and measure the latency, etc.

**Caveat**: This notebook might become out of date as support for sparse model training and inference improves in TensorFlow, TFLite, and XNNPACK. There may also be alternatives that perform better than what is shown here.

In [None]:
 # Install the relevant packages.
 !pip install -q tensorflow
 !pip install -q tensorflow-model-optimization


In [None]:
import tempfile
import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

# Dataset Preparation

In [None]:
BATCH_SIZE=64

def normalize(image, label):
  """Normalize the input to be in [-1., 1.]."""
  return 2 * ((tf.cast(image, tf.float32) / 255.) - 0.5), label

def prepare_dataset(ds, buffer_size=None):
  """Helper function to create the dataset objects for train / eval."""
  ds = ds.map(normalize, 
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  if buffer_size:
    ds = ds.shuffle(buffer_size)
  ds = ds.batch(BATCH_SIZE)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds


# Load CIFAR10 dataset.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
  'cifar10',
  split=['train[:90%]', 'train[90%:]', 'test'],
  as_supervised=True,
  with_info=True,
)

ds_train = prepare_dataset(
    ds_train,
    buffer_size=ds_info.splits['train'].num_examples)
ds_val = prepare_dataset(ds_val)
ds_test = prepare_dataset(ds_test)

# Define & train the dense model.

In [None]:
def create_dense_model():
  # Regularizer to prevent overfitting.
  reg = keras.regularizers.l2(1e-5)

  # Build the dense baseline model.
  dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
      filters=8,
      kernel_size=(3, 3),
      strides=(2, 2),
      padding='valid', 
      kernel_regularizer=reg),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),

    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),

    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1), kernel_regularizer=reg),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),  

    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=64, kernel_size=(1, 1), kernel_regularizer=reg),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(keepdims=True),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
  ])
  return dense_model
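
As a sanity check on the architecture above, the layer sizes can be worked out by hand. The sketch below is framework-free; the helper names (`conv2d_params`, `depthwise_conv2d_params`, `conv_output_size`) are hypothetical and not part of Keras. It assumes bias terms are enabled (the Keras default) and a depth multiplier of 1 for the depthwise layers.

```python
def conv2d_params(kernel, in_ch, out_ch, use_bias=True):
    """Parameters in a standard conv: k*k*in*out weights plus out biases."""
    return kernel * kernel * in_ch * out_ch + (out_ch if use_bias else 0)


def depthwise_conv2d_params(kernel, channels, use_bias=True):
    """Depthwise conv (depth multiplier 1): one k*k filter per channel."""
    return kernel * kernel * channels + (channels if use_bias else 0)


def conv_output_size(size, kernel, stride, pad):
    """Output width/height for explicit zero padding plus a 'valid' conv."""
    return (size + 2 * pad - kernel) // stride + 1


# Stem: ZeroPadding2D(1), then a 3x3 stride-2 'valid' conv from 3 to 8 channels.
stem_size = conv_output_size(32, kernel=3, stride=2, pad=1)
stem_params = conv2d_params(3, in_ch=3, out_ch=8)
print(f'stem output: {stem_size}x{stem_size}, params: {stem_params}')

# Kernel sizes of the three 1x1 (pointwise) convolutions, biases excluded.
pointwise_kernels = [conv2d_params(1, i, o, use_bias=False)
                     for i, o in [(8, 32), (32, 32), (32, 64)]]
print('pointwise kernel sizes:', pointwise_kernels)
```

The stem values agree with the 16x16 feature map and 224 parameters that `summary()` reports for the first `Conv2D` layer. The pointwise kernels hold most of the remaining weights.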

In [None]:
# Compile and train the dense model for 20 epochs.
INIT_LR = 1e-3
DECAY_RATE = 0.95
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=INIT_LR, 
      decay_steps=int(ds_info.splits['train'].num_examples / BATCH_SIZE),
      decay_rate=DECAY_RATE)

dense_model=create_dense_model()

dense_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(lr_schedule),
    metrics=['accuracy'])

dense_model.fit(
  ds_train,
  epochs=20,
  validation_data=ds_val)

# Evaluate the dense model.
_, dense_model_accuracy = dense_model.evaluate(ds_test, verbose=0)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
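
To see what the learning-rate schedule used above does over training, here is a minimal sketch of exponential decay in plain Python. It assumes the non-staircase form, `lr = initial_lr * decay_rate ** (step / decay_steps)`, that the CIFAR-10 train split has 50,000 examples, and that the 90% split therefore has 45,000. The function name `exponential_decay` is hypothetical.

```python
def exponential_decay(step, initial_lr=1e-3, decay_rate=0.95, decay_steps=781):
    """Continuous (non-staircase) decay: lr0 * rate ** (step / decay_steps)."""
    return initial_lr * decay_rate ** (step / decay_steps)


# Assumption: 50,000 CIFAR-10 training examples, so int(50000 / 64) == 781.
decay_steps = int(50_000 / 64)
# Assumption: the 90% split has 45,000 examples, i.e. 704 batches per epoch.
steps_per_epoch = -(-45_000 // 64)  # ceiling division

for epoch in (0, 1, 10, 20):
    lr = exponential_decay(epoch * steps_per_epoch, decay_steps=decay_steps)
    print(f'after epoch {epoch:2d}: lr = {lr:.6f}')
```

Under these assumptions an epoch is 704 steps while `decay_steps` is 781, so the learning rate shrinks by slightly less than a factor of 0.95 per epoch.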


# Prune the model.

We start by applying the `prune_low_magnitude` wrapper to the dense model, and then fine-tune the model while gradually pruning the lowest-magnitude weights. The `PruneForLatencyOnXNNPack` pruning policy helps us create block / structured sparsity which can then be accelerated on device. Some of the important hyper-parameters here are:

1. Number of epochs to prune: If this is too small, then the model will not be fine-tuned properly, and useful weights might be pruned.

2. Initial and final sparsity: Choosing a high initial sparsity might mean that useful weights are suddenly pruned. Ideally you want a reasonable ramp-up of sparsity.
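
The sparsity ramp-up described in the second point can be sketched in plain Python. This is an illustration of a polynomial schedule, not the library's implementation. It assumes a cubic exponent (`power=3`, which we believe matches the `PolynomialDecay` default), assumes the 90% train split has 45,000 examples, and ignores the fact that the library only updates pruning masks every few steps. The function name `polynomial_sparsity` is hypothetical.

```python
import math


def polynomial_sparsity(step, initial_sparsity=0.0, final_sparsity=0.75,
                        begin_step=0, end_step=3520, power=3):
    """Target sparsity at `step`, rising quickly at first and then levelling off."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return (final_sparsity
            + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power)


# Assumption: 45,000 training examples at batch size 64, pruned for 5 epochs.
steps_per_epoch = math.ceil(45_000 / 64)
end_step = steps_per_epoch * 5

for epoch in range(6):
    step = epoch * steps_per_epoch
    target = polynomial_sparsity(step, end_step=end_step)
    print(f'epoch {epoch}: target sparsity = {target:.3f}')
```

With a cubic exponent most of the pruning happens in the early epochs, which leaves the later epochs for the model to recover accuracy at close to the final sparsity.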


In [None]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Number of pruning epochs.
NUM_PRUNING_EPOCHS = 5

num_iterations_per_epoch = len(ds_train)
end_step =  num_iterations_per_epoch * NUM_PRUNING_EPOCHS

# Define parameters for pruning.
pruning_params = {
  'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
      initial_sparsity=0.0,
      final_sparsity=0.75,
      begin_step=0,
      end_step=end_step),
  'pruning_policy': tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
}

# Try to apply pruning wrapper with pruning policy parameter.
try:
  model_for_pruning = prune_low_magnitude(dense_model, **pruning_params)
except ValueError as e:
  print(e)

## Check the model for pruning wrappers.

The pruning library adds in wrappers for pruning. Let's compare the output of the `summary()` methods for the two models.

In [None]:
dense_model.summary()

Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 random_flip_9 (RandomFlip)  (None, 32, 32, 3)         0         
                                                                 
 zero_padding2d_18 (ZeroPadd  (None, 34, 34, 3)        0         
 ing2D)                                                          
                                                                 
 conv2d_36 (Conv2D)          (None, 16, 16, 8)         224       
                                                                 
 batch_normalization_63 (Bat  (None, 16, 16, 8)        32        
 chNormalization)                                                
                                                                 
 re_lu_63 (ReLU)             (None, 16, 16, 8)         0         
                                                                 
 depthwise_conv2d_27 (Depthw  (None, 16, 16, 8)       

In [None]:
model_for_pruning.summary()

Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 random_flip_9 (RandomFlip)  (None, 32, 32, 3)         0         
                                                                 
 zero_padding2d_18 (ZeroPadd  (None, 34, 34, 3)        0         
 ing2D)                                                          
                                                                 
 conv2d_36 (Conv2D)          (None, 16, 16, 8)         224       
                                                                 
 batch_normalization_63 (Bat  (None, 16, 16, 8)        32        
 chNormalization)                                                
                                                                 
 re_lu_63 (ReLU)             (None, 16, 16, 8)         0         
                                                                 
 depthwise_conv2d_27 (Depthw  (None, 16, 16, 8)       

In [None]:
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

INITIAL_PRUNING_LR = 1e-4
PRUNING_LR_DECAY_RATE = 0.95
pruning_lr_schedule = keras.optimizers.schedules.ExponentialDecay(
  initial_learning_rate=INITIAL_PRUNING_LR, 
  decay_steps=int(ds_info.splits['train'].num_examples / BATCH_SIZE),
  decay_rate=PRUNING_LR_DECAY_RATE)

model_for_pruning.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=keras.optimizers.Adam(pruning_lr_schedule),
  metrics=['accuracy'])

model_for_pruning.fit(
  ds_train,
  epochs=NUM_PRUNING_EPOCHS,
  validation_data=ds_val,
  callbacks=callbacks)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fa49074a8e0>

## Verify the accuracy of the dense and pruned models.

We want to ensure that the pruned model's accuracy is almost the same as the dense model's.

In [None]:
# Evaluate the pruned model.
_, pruned_model_accuracy = model_for_pruning.evaluate(ds_test, verbose=0)

print('Dense model test accuracy:', dense_model_accuracy)
print('Pruned model test accuracy:', pruned_model_accuracy)

Dense model test accuracy: 0.6122000217437744
Pruned model test accuracy: 0.6054999828338623


## Verify that the layers are actually pruned.

The pruning API does not prune every layer (in this model it prunes only the 1x1 `Conv2D` layers), and it does not prune small tensors such as biases. However, we do want to confirm that the targeted kernels were actually pruned.

In [None]:
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import PruneLowMagnitude

for idx, layer in enumerate(model_for_pruning.layers):
  # We will check the `PruneLowMagnitude` wrapper layers which have an
  # associated `layer` object, which maps to the actual layer being pruned.
  if isinstance(layer, PruneLowMagnitude):
    print(f'Wrapper for layer: {layer.layer.name}')
    for weight in layer.layer.weights:
      num_weights = weight.numpy().size
      num_nonzero_weights = np.count_nonzero(weight.numpy())
      # Sparsity is the percentage of weights that are exactly zero.
      num_zero_weights = num_weights - num_nonzero_weights
      tensor_sparsity = num_zero_weights * 100. / num_weights
      print(f'|-- {weight.name}, Num Weights: {num_weights}, '
            f'Sparsity: {tensor_sparsity:.0f}%')

Wrapper for layer: conv2d_37
|-- conv2d_37/kernel:0, Num Weights: 256, Sparsity: 75%
|-- conv2d_37/bias:0, Num Weights: 32, Sparsity: 0%
Wrapper for layer: conv2d_38
|-- conv2d_38/kernel:0, Num Weights: 1024, Sparsity: 75%
|-- conv2d_38/bias:0, Num Weights: 32, Sparsity: 0%
Wrapper for layer: conv2d_39
|-- conv2d_39/kernel:0, Num Weights: 2048, Sparsity: 75%
|-- conv2d_39/bias:0, Num Weights: 64, Sparsity: 0%
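
As a framework-free illustration of what magnitude pruning does to a single tensor, the sketch below zeroes the smallest-magnitude entries of a toy 256-element kernel and measures the resulting sparsity. The helpers `prune_low_magnitude_values` and `sparsity_percent` are hypothetical, and the unstructured selection here is a simplification rather than the TFMOT implementation.

```python
import random


def prune_low_magnitude_values(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    num_to_prune = int(round(len(weights) * sparsity))
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:num_to_prune]:
        pruned[i] = 0.0
    return pruned


def sparsity_percent(weights):
    """Percentage of entries that are exactly zero."""
    zeros = sum(1 for w in weights if w == 0.0)
    return 100.0 * zeros / len(weights)


rng = random.Random(0)
kernel = [rng.gauss(0.0, 1.0) for _ in range(256)]  # toy stand-in for a kernel
pruned = prune_low_magnitude_values(kernel, sparsity=0.75)
print(f'Sparsity before: {sparsity_percent(kernel):.0f}%')
print(f'Sparsity after:  {sparsity_percent(pruned):.0f}%')
```

Every surviving weight has a magnitude at least as large as any weight that was removed, which is the property the fine-tuning step relies on to preserve accuracy.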


# Convert the models to TFLite.

We convert the dense model directly. For the pruned model, we first strip away the wrappers added by the pruning API using `strip_pruning`, and then enable the `EXPERIMENTAL_SPARSITY` optimization in the converter.

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(dense_model)
dense_tflite_model = converter.convert()

dense_tflite_file = 'dense.tflite'
with open(dense_tflite_file, 'wb') as f:
  f.write(dense_tflite_model)



In [None]:
pruned_model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model_for_export)
converter.optimizations = [tf.lite.Optimize.EXPERIMENTAL_SPARSITY]
pruned_tflite_model = converter.convert()

pruned_tflite_file = 'sparse.tflite'
with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

# Comparing the size of the dense and sparse model upon compression.

In [None]:
import gzip
from pathlib import Path

# Create a compressed copy of input file and return its size
def get_compressed_size_in_kbs(file):
  path = Path(file)
  compressed_file = path.parent / (path.name + '.gz')

  with gzip.open(compressed_file, 'wb') as out:
    with open(file, 'rb') as inp:
      out.write(inp.read())
      
  return compressed_file.stat().st_size / 1024.

In [None]:
dense_model_size_kbs = get_compressed_size_in_kbs('dense.tflite')
sparse_model_size_kbs = get_compressed_size_in_kbs('sparse.tflite')

print(f'Dense model size: {dense_model_size_kbs:.2f} KB')
print(f'Sparse model size: {sparse_model_size_kbs:.2f} KB')
print(f'Compression: {(dense_model_size_kbs - sparse_model_size_kbs) * 100. / dense_model_size_kbs:.2f}%')

Dense model size: 13.08 KB
Sparse model size: 12.83 KB
Compression: 1.89%
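
To illustrate why zeroed weights can help compression at all, here is a toy sketch that gzips two byte buffers: one with random non-zero bytes and one with roughly 75% of those bytes set to zero. The buffers are stand-ins for weight data, not TFLite files, and `gzip_size` is a hypothetical helper.

```python
import gzip
import random


def gzip_size(data):
    """Size in bytes of the gzip-compressed buffer."""
    return len(gzip.compress(data))


rng = random.Random(0)
# Toy "dense" buffer: 4096 random non-zero bytes standing in for weights.
dense_bytes = bytes(rng.randrange(1, 256) for _ in range(4096))
# Toy "sparse" buffer: the same bytes with roughly 75% of them set to zero.
sparse_bytes = bytes(b if rng.random() >= 0.75 else 0 for b in dense_bytes)

print(f'dense:  {gzip_size(dense_bytes)} bytes after gzip')
print(f'sparse: {gzip_size(sparse_bytes)} bytes after gzip')
```

A real `.tflite` file also contains non-weight data, and only some tensors are pruned, so the measured gain can be much smaller than in this toy case.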


# Final latency benchmarking of the sparse model.
We will use the pre-built `benchmark_model` binaries that TensorFlow provides; you can also build them from source for your platform. In the colab you can run the binary directly, but on x86 you will most likely notice that the model latency doesn't vary much. However, we can see the improvement in latency on Android, since the XNNPACK delegate allows acceleration on ARM.

In general, we highly recommend going through the [README for the benchmark tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark).

In [None]:
!wget -q https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_x86-64_benchmark_model
!chmod +x linux_x86-64_benchmark_model

In [None]:
!./linux_x86-64_benchmark_model --graph=dense.tflite --use_xnnpack=true --warmup_runs=100 --num_runs=100000

STARTING!
Log parameter values verbosely: [0]
Min num runs: [100000]
Min warmup runs: [100]
Graph: [dense.tflite]
Loaded model dense.tflite
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
The input model file size (MB): 0.028156
Initialized session in 3.939ms.
Running benchmark for at least 100 iterations and at least 0.5 seconds but terminate if exceeding 150 seconds.
count=1762 first=443 curr=11447 min=66 max=33333 avg=283.205 std=1771

Running benchmark for at least 100000 iterations and at least 1 seconds but terminate if exceeding 150 seconds.
count=100000 first=84 curr=99 min=53 max=33503 avg=127.616 std=618

Inference timings in us: Init: 3939, First inference: 443, Warmup (avg): 283.205, Inference (avg): 127.616
Note: as the benchmark tool itself affects memory footprint, the following is only APPROXIMATE to the actual memory footprint of the model at runtime. Take the information at your discretion.
Memory footprint delta fr

In [None]:
!./linux_x86-64_benchmark_model --graph=sparse.tflite --use_xnnpack=true --warmup_runs=100 --num_runs=100000

STARTING!
Log parameter values verbosely: [0]
Min num runs: [100000]
Min warmup runs: [100]
Graph: [sparse.tflite]
Loaded model sparse.tflite
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
The input model file size (MB): 0.021136
Initialized session in 1.061ms.
Running benchmark for at least 100 iterations and at least 0.5 seconds but terminate if exceeding 150 seconds.
count=8426 first=177 curr=51 min=47 max=1046 avg=58.6821 std=25

Running benchmark for at least 100000 iterations and at least 1 seconds but terminate if exceeding 150 seconds.
count=100000 first=61 curr=80 min=43 max=7697 avg=59.3359 std=54

Inference timings in us: Init: 1061, First inference: 177, Warmup (avg): 58.6821, Inference (avg): 59.3359
Note: as the benchmark tool itself affects memory footprint, the following is only APPROXIMATE to the actual memory footprint of the model at runtime. Take the information at your discretion.
Memory footprint delta from the

The two runs above, with the dense and sparse models, might report very similar numbers because, as noted earlier, XNNPACK is optimized for ARM; you are likely to see a larger improvement on device. Run the following commands to benchmark the models on your device.

## Download the models to your machine.
In order to benchmark on your device, let's download the dense and sparse models first to your machine. 

In [None]:
from google.colab import files

In [None]:
files.download('dense.tflite')

In [None]:
files.download('sparse.tflite')

## Download all the required binaries, and push them to the device.

For Android devices, we need the Android Debug Bridge (ADB), which lets us push binaries and files to the device and run commands on it directly. You can go through the instructions [here](https://developer.android.com/studio/command-line/adb) to install ADB on your machine, if you don't have it already.

Once you have `adb` working on your machine, you need to download the right `benchmark_model` binary and push it to the device. Note that you can build your own binary by cloning the TensorFlow repository and building the `benchmark_model` target, or pick any other binary from [here](https://www.tensorflow.org/lite/performance/measurement#native_benchmark_binary).

```
wget https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model
```

```
adb push android_aarch64_benchmark_model /data/local/tmp/benchmark_model
```

Finally, push the models to the device as well.

```
adb push ~/Downloads/dense.tflite /data/local/tmp/
```

```
adb push ~/Downloads/sparse.tflite /data/local/tmp/
```

## Benchmark the models on the device.

We can now run the `benchmark_model` binary on the device and compare the results.

```
$ adb shell ./data/local/tmp/benchmark_model --graph=/data/local/tmp/dense.tflite --warmup_runs=5 --num_runs=2000
STARTING!
Log parameter values verbosely: [0]
Min num runs: [2000]
Min warmup runs: [5]
Graph: [/data/local/tmp/dense.tflite]
Loaded model /data/local/tmp/dense.tflite
INFO: Initialized TensorFlow Lite runtime.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
VERBOSE: Replacing 12 node(s) with delegate (TfLiteXNNPackDelegate) node, yielding 1 partitions for the whole graph.
The input model file size (MB): 0.028008
Initialized session in 5.401ms.
Running benchmark for at least 5 iterations and at least 0.5 seconds but terminate if exceeding 150 seconds.
count=3380 first=729 curr=135 min=132 max=729 avg=145.822 std=33

Running benchmark for at least 2000 iterations and at least 1 seconds but terminate if exceeding 150 seconds.
count=7186 first=142 curr=140 min=132 max=1374 avg=137.33 std=18

Inference timings in us: Init: 5401, First inference: 729, Warmup (avg): 145.822, Inference (avg): 137.33
Note: as the benchmark tool itself affects memory footprint, the following is only APPROXIMATE to the actual memory footprint of the model at runtime. Take the information at your discretion.
Memory footprint delta from the start of the tool (MB): init=1.15625 overall=1.15625
```

```
$ adb shell ./data/local/tmp/benchmark_model --graph=/data/local/tmp/sparse.tflite --warmup_runs=5 --num_runs=2000
STARTING!
Log parameter values verbosely: [0]
Min num runs: [2000]
Min warmup runs: [5]
Graph: [/data/local/tmp/sparse.tflite]
Loaded model /data/local/tmp/sparse.tflite
INFO: Initialized TensorFlow Lite runtime.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
VERBOSE: Replacing 15 node(s) with delegate (TfLiteXNNPackDelegate) node, yielding 1 partitions for the whole graph.
The input model file size (MB): 0.020972
Initialized session in 6.295ms.
Running benchmark for at least 5 iterations and at least 0.5 seconds but terminate if exceeding 150 seconds.
count=5683 first=463 curr=80 min=78 max=686 avg=85.9257 std=26

Running benchmark for at least 2000 iterations and at least 1 seconds but terminate if exceeding 150 seconds.
count=12159 first=87 curr=80 min=78 max=1540 avg=80.4727 std=20

Inference timings in us: Init: 6295, First inference: 463, Warmup (avg): 85.9257, Inference (avg): 80.4727
Note: as the benchmark tool itself affects memory footprint, the following is only APPROXIMATE to the actual memory footprint of the model at runtime. Take the information at your discretion.
Memory footprint delta from the start of the tool (MB): init=1.21875 overall=1.21875
```

## Conclusion
As you can see from the output above, the dense model takes about 137us per inference on average, whereas the sparse model takes about 80us. Hence we reduce the average latency by roughly 41%, with only a small drop in test accuracy (61.2% for the dense model vs. 60.5% for the pruned model). As the support for pruning and sparse inference improves in TFLite and in other frameworks, you might see more sparse models being used in production.
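
The latency comparison can be reproduced from the benchmark logs with a short script. The sketch below parses the `Inference (avg)` field from the two on-device summary lines shown earlier and computes the relative reduction. The helper `parse_inference_avg_us` is hypothetical, and the regular expression assumes the summary-line format seen in these logs.

```python
import re


def parse_inference_avg_us(summary_line):
    """Extract the average inference latency (microseconds) from a summary line."""
    match = re.search(r'Inference \(avg\): ([0-9.]+)', summary_line)
    if match is None:
        raise ValueError('No "Inference (avg)" field found')
    return float(match.group(1))


# Summary lines copied from the on-device benchmark output.
dense_line = ('Inference timings in us: Init: 5401, First inference: 729, '
              'Warmup (avg): 145.822, Inference (avg): 137.33')
sparse_line = ('Inference timings in us: Init: 6295, First inference: 463, '
               'Warmup (avg): 85.9257, Inference (avg): 80.4727')

dense_us = parse_inference_avg_us(dense_line)
sparse_us = parse_inference_avg_us(sparse_line)
reduction_percent = (dense_us - sparse_us) * 100.0 / dense_us
print(f'Dense: {dense_us:.1f}us, sparse: {sparse_us:.1f}us, '
      f'reduction: {reduction_percent:.1f}%')
```

Latencies vary between runs and devices, so it is worth repeating the benchmark a few times before drawing conclusions from a single pair of numbers.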