In [20]:
import tensorflow as tf
tf.enable_eager_execution()

import tempfile
import zipfile
import os

In [21]:
# Hyper-parameters shared by the training runs in this notebook.
batch_size = 128
num_classes = 10
epochs = 10

# Height and width of a single input image.
img_rows, img_cols = 28, 28

# Fashion-MNIST arrives already split into train and test sets.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Add an explicit single-channel axis, placed wherever the active Keras
# backend expects it.
channels_first = tf.keras.backend.image_data_format() == 'channels_first'
input_shape = (1, img_rows, img_cols) if channels_first else (img_rows, img_cols, 1)
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Convert to float32 and divide by 255 (presumably mapping 8-bit pixel
# intensities into [0, 1] -- confirm against the dataset documentation).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Labels are intentionally left as integer class ids (no one-hot encoding):
# the models in this notebook are compiled with SparseCategoricalCrossentropy.

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


## Train a fashion_mnist model without pruning

In [26]:
# Short alias for the Keras layers namespace; later cells reuse it.
l = tf.keras.layers

# Baseline (unpruned) classifier: conv -> pool -> batch-norm -> dense head.
# Only the first layer declares `input_shape`; every later layer infers its
# input from the layer before it. (Flatten previously carried a stray
# `input_shape=input_shape`, which did not describe its real input -- the
# pooled feature map -- so it has been dropped.)
model = tf.keras.Sequential([
    l.Conv2D(
        32, 5, padding='same', activation='relu', input_shape=input_shape),
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.BatchNormalization(),
    l.Flatten(),
    l.Dense(128, activation='relu'),
    # No activation: the model outputs raw logits.
    l.Dense(num_classes)
])

model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 28, 28, 32)        832       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 32)        128       
_________________________________________________________________
flatten_3 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_6 (Dense)              (None, 128)               802944    
_________________________________________________________________
dense_7 (Dense)              (None, 10)                1290      
Total params: 805,194
Trainable params: 805,130
Non-trainable params: 64
_______________________________________________

In [27]:
# from_logits=True: the loss applies softmax itself, so the model's last
# layer is expected to output raw logits. Labels are integer class ids.
baseline_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(loss=baseline_loss, optimizer='adam', metrics=['accuracy'])

In [28]:
# Train the baseline model. Note that the test split doubles as the
# validation set here.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=epochs,
    batch_size=batch_size,
    verbose=1,
)

# evaluate() returns [loss, accuracy], matching the compiled metrics.
score = model.evaluate(x_test, y_test, verbose=0)
for label, value in zip(('Test loss:', 'Test accuracy:'), score):
  print(label, value)

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test loss: 0.4387332788825035
Test accuracy: 0.9068


In [29]:
# Backend agnostic way to save/restore models
# mkstemp() returns an *open* OS-level file descriptor together with the
# path. Only the path is needed, so close the descriptor right away instead
# of leaking it for the lifetime of the kernel.
fd, keras_file = tempfile.mkstemp('.h5')
os.close(fd)
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

Saving model to:  /tmp/tmpps3xyecv.h5


## Train a pruned fashion_mnist model

In [47]:
from tensorflow_model_optimization.sparsity import keras as sparsity

In [48]:
import numpy as np

# The pruning schedule is expressed in optimizer steps (one step per batch),
# so convert the planned number of epochs into a total step count.
epochs = 10
num_train_samples = x_train.shape[0]
steps_per_epoch = np.ceil(float(num_train_samples) / batch_size).astype(np.int32)
end_step = steps_per_epoch * epochs
print('End step: ' + str(end_step))

End step: 4690


In [49]:
# One schedule shared by every pruned layer: target sparsity goes from 25%
# to 90%, starting at step 2000 and finishing at `end_step`, with the masks
# updated every 100 steps.
pruning_params = {
      'pruning_schedule': sparsity.PolynomialDecay(initial_sparsity=0.25,
                                                   final_sparsity=0.90,
                                                   begin_step=2000,
                                                   end_step=end_step,
                                                   frequency=100)
}

# Same topology as the baseline model (plus dropout before the output
# layer), with each weight-bearing layer wrapped for magnitude pruning.
pruned_model = tf.keras.Sequential([
    sparsity.prune_low_magnitude(
        l.Conv2D(32, 5, padding='same', activation='relu'),
        input_shape=input_shape,
        **pruning_params),
    l.MaxPooling2D((2, 2), (2, 2), padding='same'),
    l.BatchNormalization(),
    l.Flatten(),
    sparsity.prune_low_magnitude(l.Dense(128, activation='relu'),
                                 **pruning_params),
    l.Dropout(0.4),
    # The output layer emits raw logits (no softmax). This model is compiled
    # with SparseCategoricalCrossentropy(from_logits=True), which applies
    # softmax itself; a softmax activation here would be applied twice and
    # distort the loss and its gradients.
    sparsity.prune_low_magnitude(l.Dense(num_classes),
                                 **pruning_params)
])

pruned_model.summary()

Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_conv2d_3 (None, 28, 28, 32)        1634      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 14, 14, 32)        128       
_________________________________________________________________
flatten_6 (Flatten)          (None, 6272)              0         
_________________________________________________________________
prune_low_magnitude_dense_12 (None, 128)               1605762   
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0         
_________________________________________________________________
prune_low_magnitude_dense_13 (None, 10)               

In [50]:
# Fresh scratch directory that will receive the training summaries.
logdir = tempfile.mkdtemp()
print('Writing training logs to {}'.format(logdir))

Writing training logs to /tmp/tmppezqd_oe


In [51]:
# from_logits=True makes the loss apply softmax itself, so the model's final
# layer must output raw logits for this loss to be correct.
pruned_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
pruned_model.compile(optimizer='adam', loss=pruned_loss, metrics=['accuracy'])

# Two pruning-specific callbacks:
#  * UpdatePruningStep pegs the pruning step to the optimizer's step.
#  * PruningSummaries writes pruning summaries for TensorBoard into `logdir`.
step_callback = sparsity.UpdatePruningStep()
summary_callback = sparsity.PruningSummaries(log_dir=logdir, profile_batch=0)
callbacks = [step_callback, summary_callback]

# Note that the test split doubles as the validation set here.
pruned_model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    callbacks=callbacks,
    epochs=epochs,
    batch_size=batch_size,
    verbose=1,
)

# evaluate() returns [loss, accuracy], matching the compiled metrics.
score = pruned_model.evaluate(x_test, y_test, verbose=0)
for label, value in zip(('Test loss:', 'Test accuracy:'), score):
  print(label, value)

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
INFO:tensorflow:Summary name prune_low_magnitude_dense_12/mask:0/sparsity is illegal; using prune_low_magnitude_dense_12/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_13/mask:0/sparsity is illegal; using prune_low_magnitude_dense_13/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_3/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_3/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_12/threshold:0/threshold is illegal; using prune_low_magnitude_dense_12/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_13/threshold:0/threshold is illegal; using prune_low_magnitude_dense_13/threshold_0/threshold instead.
Epoch 2/10
INFO:tensorflow:Summary name prune_low_magnitude_dense_12/mask:0/sparsity is illegal; using prune_low_magnitude_dense_12/mask_0/sparsity instead.
INFO:tensorflow:

INFO:tensorflow:Summary name prune_low_magnitude_dense_12/mask:0/sparsity is illegal; using prune_low_magnitude_dense_12/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_13/mask:0/sparsity is illegal; using prune_low_magnitude_dense_13/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_conv2d_3/threshold:0/threshold is illegal; using prune_low_magnitude_conv2d_3/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_12/threshold:0/threshold is illegal; using prune_low_magnitude_dense_12/threshold_0/threshold instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_13/threshold:0/threshold is illegal; using prune_low_magnitude_dense_13/threshold_0/threshold instead.
Epoch 9/10
INFO:tensorflow:Summary name prune_low_magnitude_dense_12/mask:0/sparsity is illegal; using prune_low_magnitude_dense_12/mask_0/sparsity instead.
INFO:tensorflow:Summary name prune_low_magnitude_dense_13/mask:0/sparsity is 

In [52]:
# Remove the pruning wrappers added by prune_low_magnitude so that only the
# plain Keras layers remain in the model that gets exported.
final_model = sparsity.strip_pruning(pruned_model)

final_model.summary()

Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_3 (Conv2D)            (None, 28, 28, 32)        832       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 14, 14, 32)        128       
_________________________________________________________________
flatten_6 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_12 (Dense)             (None, 128)               802944    
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_13 (Dense)             (None, 10)               

In [53]:
# mkstemp() returns an *open* OS-level file descriptor together with the
# path. Only the path is needed, so close the descriptor right away instead
# of leaking it for the lifetime of the kernel.
fd, pruned_keras_file = tempfile.mkstemp('.h5')
os.close(fd)
print('Saving pruned model to: ', pruned_keras_file)

# No need to save the optimizer with the graph for serving.
tf.keras.models.save_model(final_model, pruned_keras_file, include_optimizer=False)

Saving pruned model to:  /tmp/tmpnu8ph786.h5


## Compare the size of the unpruned vs. pruned model after compression

In [54]:
def _zip_and_report(model_path, label):
  """Deflate-compress a saved model into a temp zip and print both sizes.

  Args:
    model_path: Path of the saved model file to compress.
    label: Word inserted into the printed messages (e.g. 'pruned').

  Returns:
    Path of the zip archive that was written.
  """
  # mkstemp() hands back an open descriptor; close it so it is not leaked.
  # ZipFile reopens the file by path below.
  fd, zip_path = tempfile.mkstemp('.zip')
  os.close(fd)
  with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    archive.write(model_path)
  # Sizes are reported in units of 2**20 bytes.
  print("Size of the %s model before compression: %.2f Mb" %
        (label, os.path.getsize(model_path) / float(2**20)))
  print("Size of the %s model after compression: %.2f Mb" %
        (label, os.path.getsize(zip_path) / float(2**20)))
  return zip_path


# Both .h5 files are expected to be about the same size on disk; the effect
# of pruning shows up only after compression.
zip1 = _zip_and_report(keras_file, 'unpruned')
zip2 = _zip_and_report(pruned_keras_file, 'pruned')

Size of the unpruned model before compression: 3.09 Mb
Size of the unpruned model after compression: 2.86 Mb
Size of the pruned model before compression: 3.09 Mb
Size of the pruned model after compression: 0.60 Mb
