In [None]:
# Uncomment to install the TensorFlow Model Optimization toolkit (imported below as `tfmot`).
# `%pip` installs into the environment of the running kernel, unlike `!pip`.
# %pip install -q tensorflow-model-optimization

In [5]:
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot

## Overview

Not all the features of a model contribute equally, and likewise not all the weights contribute equally to the model's performance. So why not get rid of the ones that contribute little?

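Sparse models (models in which many weights are exactly zero) are lighter: they compress much better, so they are smaller to store and ship. With a runtime that exploits sparsity, the zeroed weights can also be skipped at inference time, making the model faster (which can be really important when latency is an issue), usually without a large loss of accuracy.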

## Training a model without and with pruning


We use the classic MNIST dataset for this test. We first train a model normally and then prune it.
We use low-magnitude pruning. During training, at regular intervals (every 100 steps by default in TFMOT), this method sets to zero the weights with the smallest magnitude. A weight with a small absolute value contributes little to the output, so in theory it can be removed without much harm to the model. Recall that a weight equal to zero no longer contributes to the model. Low-magnitude pruning therefore simply zeroes the weights whose absolute value falls below a threshold, chosen so that the target fraction of zero weights (the sparsity) is reached. In general, this method removes weights that contribute little, with a low risk of losing the connections that are important for performance (here, accuracy).

In [13]:
# Seed NumPy and TensorFlow so that weight initialisation and the batch shuffling
# done by fit() are repeatable across runs. (Bit-exact results on GPU may
# additionally require deterministic ops.)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Keras provides the MNIST dataset (28x28 grayscale digit images).
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Scale the pixel values from [0, 255] to [0, 1].
train_images = train_images / 255.0
test_images = test_images / 255.0

# A simple fully connected network. The Reshape layer only adds a channel axis,
# which is what a Conv2D variant of this model would need. Because Flatten follows
# it directly, it has no effect here; it is kept so the architecture can easily be
# swapped for a CNN.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10),  # raw logits, hence from_logits=True in the loss
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Keep the History object instead of letting the cell echo its repr.
# validation_split holds out the last 10% of the training samples for validation.
baseline_history = model.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.1,
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7fd86c786110>

In [6]:
# Show the layer-by-layer architecture and the number of parameters of the dense
# baseline model (the reference point for the pruning-wrapped model later on).
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                50240     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
Total params: 50,890
Trainable params: 50,890
Non-trainable params: 0
_________________________________________________________________


In [14]:
# Test-set accuracy of the dense baseline; evaluate() returns [loss, accuracy].
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('before pruning test accuracy:', baseline_model_accuracy)

# Save the baseline (without optimizer state) to a temporary .h5 file.
# mkstemp() creates the file and returns an already-open OS-level file descriptor.
# Close it immediately: Keras re-opens the path by name, and an unclosed
# descriptor would otherwise leak for the lifetime of the kernel.
keras_file_fd, keras_file = tempfile.mkstemp('.h5')
os.close(keras_file_fd)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved pre-pruned model to:', keras_file)

before pruning test accuracy: 0.9711999893188477
Saved pre-pruned model to: /tmp/tmpsb71m2l0.h5


In [15]:
# prune_low_magnitude wraps the model's layers so that, during training, the weights
# with the smallest absolute values are progressively masked to zero according to
# a sparsity schedule.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Fine-tune the already trained model for just two epochs while pruning.
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of the training set is held out as the validation set.

# Keras keeps floor(n * (1 - validation_split)) samples for training. Computing the
# same integer here, rather than a float, makes end_step match the real number of
# optimizer steps instead of depending on floating-point rounding inside ceil().
num_images = int(train_images.shape[0] * (1 - validation_split))
end_step = int(np.ceil(num_images / batch_size)) * epochs

# Sparsity ramps from 50% at step 0 up to 80% at end_step (the last training step).
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step,
    )
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# The pruning-wrapped model is a new model object, so it must be compiled again.
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

# The extra non-trainable parameters in this summary are the pruning wrappers'
# bookkeeping variables (masks, thresholds, step counters).
model_for_pruning.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape_ (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 784)               1         
_________________________________________________________________
prune_low_magnitude_dense_2  (None, 64)                100418    
_________________________________________________________________
prune_low_magnitude_dense_3  (None, 10)                1292      
Total params: 101,712
Trainable params: 50,890
Non-trainable params: 50,822
_________________________________________________________________


In [17]:
# Scratch directory for the TensorBoard logs written by PruningSummaries.
logdir = tempfile.mkdtemp()

callbacks = [
    # Required by TFMOT when training a pruning-wrapped model: it advances the
    # pruning step on every batch, which is what drives the sparsity schedule.
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Optional: logs pruning information (e.g. per-layer sparsity) for TensorBoard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(
    x=train_images,
    y=train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fd68ef13b50>

In [18]:
# Test-set accuracy of the pruned model; evaluate() returns [loss, accuracy].
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
    x=test_images, y=test_labels, verbose=0
)

# Report the dense baseline and the pruned model side by side.
for label, acc in (
    ('pre-pruning test accuracy:', baseline_model_accuracy),
    ('post-pruning test accuracy:', model_for_pruning_accuracy),
):
    print(label, acc)

pre-pruning test accuracy: 0.9711999893188477
post-pruning test accuracy: 0.9671000242233276


In [11]:
# strip_pruning removes the pruning wrappers and their bookkeeping variables
# (masks, thresholds, step counters), giving back a plain Keras model. The weights
# zeroed during pruning stay zero in the kernels. An .h5 file still stores them as
# dense arrays, so the size benefit only shows up once the file is compressed.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# mkstemp() returns an already-open OS-level file descriptor. Close it so it does
# not leak; Keras re-opens the path by name when saving.
pruned_keras_fd, pruned_keras_file = tempfile.mkstemp('.h5')
os.close(pruned_keras_fd)
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

Saved pruned Keras model to: /tmp/tmpnsgvfnva.h5
