In [1]:
import tensorflow as tf
tf.enable_eager_execution()

import tempfile
import zipfile
import os

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
batch_size = 128
num_classes = 10
epochs = 15

# input image dimensions
img_rows, img_cols, img_dim = 32, 32, 3

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

input_shape = (img_rows, img_cols, img_dim)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
#y_train = tf.keras.utils.to_categorical(y_train, num_classes)
#y_test = tf.keras.utils.to_categorical(y_test, num_classes)

x_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples


## Train a CIFAR-10 model without pruning

In [3]:
l = tf.keras.layers

model = tf.keras.Sequential([
    l.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
    l.MaxPooling2D((2, 2)),
    l.Conv2D(64, (3, 3), activation='relu'),
    l.MaxPooling2D((2, 2)),
    l.Conv2D(64, (3, 3), activation='relu'),
    l.Flatten(),
    l.Dense(64, activation='relu'),
    l.Dense(num_classes)
])

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 64)                6

In [4]:
logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)

Writing training logs to /tmp/tmpdtkc0n9_


In [5]:
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

history = model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))

Train on 50000 samples, validate on 10000 samples
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


In [6]:
# Backend agnostic way to save/restore models
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

Saving model to:  /tmp/tmpf5bj9lhx.h5


## Train a pruned CIFAR-10

In [31]:
from tensorflow_model_optimization.sparsity import keras as sparsity

In [32]:
import numpy as np

epochs = 15
num_train_samples = x_train.shape[0]
end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs
print('End step: ' + str(end_step))

End step: 5865


In [33]:
pruning_params = {
      'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.25,
                                                   final_sparsity=0.90,
                                                   begin_step=2000,
                                                   end_step=end_step,
                                                   frequency=100)
}

pruned_model = tf.keras.Sequential([
    sparsity.prune_low_magnitude(
        l.Conv2D(32, (3, 3), padding='same', activation='relu'),
        input_shape=input_shape,
        **pruning_params),
    l.MaxPooling2D((2, 2)),
    sparsity.prune_low_magnitude(l.Conv2D(64, (3, 3), padding='same', activation='relu'), **pruning_params),
    l.MaxPooling2D((2, 2)),
    sparsity.prune_low_magnitude(l.Conv2D(64, (3, 3), padding='same', activation='relu'), **pruning_params),
    l.Flatten(),
    sparsity.prune_low_magnitude(l.Dense(64, activation='relu'), **pruning_params),
    sparsity.prune_low_magnitude(l.Dense(num_classes, activation='softmax'), **pruning_params)
])

pruned_model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_conv2d_1 (None, 32, 32, 32)        1762      
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 16, 16, 64)        36930     
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 8, 8, 64)          0         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 8, 8, 64)          73794     
_________________________________________________________________
flatten_4 (Flatten)          (None, 4096)              0         
_________________________________________________________________
prune_low_magnitude_dense_8  (None, 64)               

In [34]:
logdir = tempfile.mkdtemp()
print('Writing training logs to ' + logdir)

Writing training logs to /tmp/tmpvj87vrde


In [35]:
pruned_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

# Add a pruning step callback to peg the pruning step to the optimizer's
# step. Also add a callback to add pruning summaries to tensorboard
callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
]

pruned_model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          callbacks=callbacks,
          validation_data=(x_test, y_test))

score = pruned_model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Train on 50000 samples, validate on 10000 samples
Epoch 1/15
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_13/mask:0/sparsity is illegal; using prune_low_magnitude_conv2d_13/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_14/mask:0/sparsity is illegal; using prune_low_magnitude_conv2d_14/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_8/mask:0/sparsity is illegal; using prune_low_magnitude_dense_8/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_9/mask:0/sparsity is illegal; using prune_low_magnitude_dense_9/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_12/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_12/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_13/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_13/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_

INFO:tensorflow:Summary name prune_low_magnitude_dense_8/threshold:0/threshold is illegal; using prune_low_magnitude_dense_8/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_9/threshold:0/threshold is illegal; using prune_low_magnitude_dense_9/threshold_0/threshold instead.
Epoch 6/15
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_13/mask:0/sparsity is illegal; using prune_low_magnitude_conv2d_13/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_14/mask:0/sparsity is illegal; using prune_low_magnitude_conv2d_14/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_8/mask:0/sparsity is illegal; using prune_low_magnitude_dense_8/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_9/mask:0/sparsity is illegal; using prune_low_magnitude_dense_9/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_12/threshold:0/threshold is illegal; 

INFO:tensorflow:Summary name prune_low_magnitude_conv2d_12/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_12/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_13/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_13/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_14/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_14/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_8/threshold:0/threshold is illegal; using prune_low_magnitude_dense_8/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_9/threshold:0/threshold is illegal; using prune_low_magnitude_dense_9/threshold_0/threshold instead.
Epoch 11/15
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_13/mask:0/sparsity is illegal; using prune_low_magnitude_conv2d_13/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude

INFO:tensorflow:Summary name prune_low_magnitude_conv2d_14/mask:0/sparsity is illegal; using prune_low_magnitude_conv2d_14/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_8/mask:0/sparsity is illegal; using prune_low_magnitude_dense_8/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_9/mask:0/sparsity is illegal; using prune_low_magnitude_dense_9/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_12/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_12/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_13/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_13/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_14/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_14/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_8/threshold:0/threshold is il

In [36]:
final_model = sparsity.strip_pruning(pruned_model)
final_model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_12 (Conv2D)           (None, 32, 32, 32)        896       
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 16, 16, 64)        18496     
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 8, 8, 64)          36928     
_________________________________________________________________
flatten_4 (Flatten)          (None, 4096)              0         
_________________________________________________________________
dense_8 (Dense)              (None, 64)               

In [37]:
_, pruned_keras_file = tempfile.mkstemp('.h5')
print('Saving pruned model to: ', pruned_keras_file)

# No need to save the optimizer with the graph for serving.
tf.keras.models.save_model(final_model, pruned_keras_file, include_optimizer=False)

Saving pruned model to:  /tmp/tmply8tdzz5.h5


## Compare the size of the unpruned vs. pruned model after compression

In [38]:
_, zip1 = tempfile.mkstemp('.zip') 
with zipfile.ZipFile(zip1, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(keras_file)
print("Size of the unpruned model before compression: %.2f Mb" % 
      (os.path.getsize(keras_file) / float(2**20)))
print("Size of the unpruned model after compression: %.2f Mb" % 
      (os.path.getsize(zip1) / float(2**20)))

_, zip2 = tempfile.mkstemp('.zip') 
with zipfile.ZipFile(zip2, 'w', compression=zipfile.ZIP_DEFLATED) as f:
  f.write(pruned_keras_file)
print("Size of the pruned model before compression: %.2f Mb" % 
      (os.path.getsize(pruned_keras_file) / float(2**20)))
print("Size of the pruned model after compression: %.2f Mb" % 
      (os.path.getsize(zip2) / float(2**20)))

Size of the unpruned model before compression: 0.49 Mb
Size of the unpruned model after compression: 0.44 Mb
Size of the pruned model before compression: 1.24 Mb
Size of the pruned model after compression: 0.24 Mb


## Convert to TensorFlow Lite