In [1]:
import tensorflow as tf

In [68]:
from tensorflow import keras
import numpy as np
import os
import zipfile
import tempfile

In [3]:
# Load MNIST through the Keras dataset helper and unpack the train/test
# image and label arrays.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [4]:
# Divide both image arrays by 255.0.
# NOTE: the names are rebound in place, so re-running this cell divides again;
# execute it exactly once per kernel session.
train_images, test_images = train_images / 255.0, test_images / 255.0

In [98]:
# Two Conv2D -> MaxPooling2D stages followed by a 10-unit Dense head with no
# activation (the loss used at compile time is configured with from_logits=True).
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  # Append a trailing channel axis so the convolutions receive (28, 28, 1) inputs.
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=16, kernel_size=3, activation='relu'),
  keras.layers.MaxPooling2D(pool_size=2, strides=2),
  keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu'),
  keras.layers.MaxPooling2D(pool_size=2, strides=2),
  keras.layers.Flatten(),
  keras.layers.Dense(units=10),
])

In [99]:
# The Dense head emits raw logits, hence from_logits=True on the loss.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [100]:
# Print the layer-by-layer output shapes and parameter counts of the baseline.
model.summary()

Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_8 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 26, 26, 16)        160       
_________________________________________________________________
max_pooling2d_12 (MaxPooling (None, 13, 13, 16)        0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 11, 11, 32)        4640      
_________________________________________________________________
max_pooling2d_13 (MaxPooling (None, 5, 5, 32)          0         
_________________________________________________________________
flatten_9 (Flatten)          (None, 800)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 10)               

In [101]:
# Train the baseline for 4 epochs with validation_split=0.2.
model.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.2,
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<tensorflow.python.keras.callbacks.History at 0x7f380471ee80>

In [102]:
# evaluate() returns the loss followed by the compiled metrics; only the
# accuracy is kept as the baseline reference for the pruned models.
_, base_test = model.evaluate(test_images, test_labels)
print('base Test set accuracy:', base_test)

base Test set accuracy: 0.9873999953269958


In [103]:
# mkstemp returns an already-open OS-level file descriptor together with the
# path. Close the descriptor immediately so it is not leaked; save_model
# reopens the file by path.
keras_fd, keras_file = tempfile.mkstemp('.h5')
os.close(keras_fd)
tf.keras.models.save_model(model, keras_file)
print('base model:', keras_file)

base model: /tmp/tmp5j90ybxu.h5


# Pruning Begins!

In [82]:
!pip install tensorflow_model_optimization



In [28]:
import tensorflow_model_optimization as tfmot

In [105]:
# Short alias for the TFMOT pruning entry point used by the cells below.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Settings shared by the pruning fine-tuning runs (passed to fit() and used
# to derive end_step).
batch_size = 128
epochs = 2
validation_split = 0.2 

In [106]:
# Sanity check: the first dimension is the image count used to derive end_step.
train_images.shape

(60000, 28, 28)

In [114]:
# Number of images left for training after the validation split is held out.
num_images = train_images.shape[0] * (1 - validation_split) # 48,000 num_images
# Keras also runs the final partial batch of each epoch, so the number of
# optimizer steps per epoch is a ceiling division, not a floor. Integer
# arithmetic (-(-a // b)) avoids float rounding pushing the ceiling up by one.
steps_per_epoch = -(-int(num_images) // batch_size)
# Step at which the pruning schedule reaches its final sparsity.
end_step = steps_per_epoch * epochs

### Pruning with 50% initial sparsity (ramping to 80% zero weights)

In [122]:
# Sparsity schedule: start at 50% zeroed weights at step 0 and reach 80% at end_step.
pruning_params_1 = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

# prune_low_magnitude wraps the layer objects it is given instead of copying
# them, so pruning `model` directly would overwrite the baseline's weights
# while the pruned model trains. Prune a fresh copy of the saved baseline so
# `model` stays dense and every experiment starts from the same weights.
# compile=False: the pruned model is compiled separately before training.
baseline_for_pruning_1 = tf.keras.models.load_model(keras_file, compile=False)
pruning_model_1 = prune_low_magnitude(baseline_for_pruning_1, **pruning_params_1)


In [123]:
# The wrapped model is a new model object and must be compiled before
# training; the loss/optimizer/metric settings mirror the baseline's.
pruning_model_1.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

pruning_model_1.summary()

Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape_ (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 26, 26, 16)        306       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 16)        1         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 11, 11, 32)        9250      
_________________________________________________________________
prune_low_magnitude_max_pool (None, 5, 5, 32)          1         
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 800)               1         
_________________________________________________________________
prune_low_magnitude_dense_9  (None, 10)               

In [124]:
# Fresh temporary directory handed to PruningSummaries as its log_dir.
logdir = tempfile.mkdtemp()

# UpdatePruningStep is passed alongside PruningSummaries, which writes to logdir.
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

pruning_model_1.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f3801094278>

In [128]:
# Evaluate the first pruned model on the test set and print its accuracy next
# to the baseline's.
_, pruning_model_1_accuracy = pruning_model_1.evaluate(test_images, test_labels)

print('base test accuracy:', base_test) 
print('Pruned model 1 test accuracy:', pruning_model_1_accuracy)

base test accuracy: 0.9873999953269958
Pruned model 1 test accuracy: 0.9850999712944031


In [130]:
# Pass the trained pruning model through strip_pruning to obtain the model that
# is saved below (presumably this drops the training-only pruning wrappers).
model_for_export = tfmot.sparsity.keras.strip_pruning(pruning_model_1)

In [131]:
# mkstemp returns an open descriptor plus the path; close the descriptor right
# away so it is not leaked (save_model reopens the file by path).
pruned_model_1_fd, pruned_model_1_keras_file = tempfile.mkstemp('.h5')
os.close(pruned_model_1_fd)
tf.keras.models.save_model(model_for_export, pruned_model_1_keras_file)
print('Saved pruned model 1 to:', pruned_model_1_keras_file)

Saved pruned model 1 to: /tmp/tmpw0g2yc09.h5


In [95]:
def zip_model(file):
  """Return the size in bytes of `file` once DEFLATE-compressed into a zip archive.

  The archive is written to a temporary path that is removed before returning;
  only its size is reported.

  Args:
    file: Path of the file to compress.

  Returns:
    Size of the resulting zip archive in bytes (int).
  """
  # mkstemp returns an already-open descriptor; close it immediately so it is
  # not leaked. ZipFile reopens the path itself.
  zip_fd, zip_file = tempfile.mkstemp('.zip')
  os.close(zip_fd)
  try:
    with zipfile.ZipFile(zip_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(file)
    return os.path.getsize(zip_file)
  finally:
    # The archive path is never handed back to the caller, so do not leave it
    # behind in the temp directory.
    os.remove(zip_file)

In [132]:
# Compare zipped sizes of the saved baseline and the first pruned model.
print("Size of base model: %.2f bytes" % (zip_model(keras_file)))
print("Size of pruned model: %.2f bytes" % (zip_model(pruned_model_1_keras_file)))

Size of base model: 147786.00 bytes
Size of pruned model: 17432.00 bytes


### Pruning with 30% initial sparsity (ramping to 80% zero weights)

In [133]:
# Sparsity schedule: start at 30% zeroed weights at step 0 and reach 80% at end_step.
pruning_params_2 = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.30,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

# prune_low_magnitude wraps the layer objects it is given instead of copying
# them, so `model` may already carry weights altered by an earlier pruning run.
# Reload the dense baseline saved to keras_file so this experiment really
# starts from the unpruned weights.
# compile=False: the pruned model is compiled separately before training.
baseline_for_pruning_2 = tf.keras.models.load_model(keras_file, compile=False)
pruning_model_2 = prune_low_magnitude(baseline_for_pruning_2, **pruning_params_2)


In [136]:
# Compile the second pruning model with the same optimizer, loss and metric
# settings used for the baseline.
pruning_model_2.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

pruning_model_2.summary()

Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape_ (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 26, 26, 16)        306       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 16)        1         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 11, 11, 32)        9250      
_________________________________________________________________
prune_low_magnitude_max_pool (None, 5, 5, 32)          1         
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 800)               1         
_________________________________________________________________
prune_low_magnitude_dense_9  (None, 10)               

In [137]:
# New temporary log directory for this run (rebinds `logdir`).
logdir = tempfile.mkdtemp()

# New callback instances for this run: UpdatePruningStep plus PruningSummaries
# writing to logdir.
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

pruning_model_2.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f380313ac88>

In [138]:
# Evaluate the second pruned model on the test set and print its accuracy next
# to the baseline's.
_, pruning_model_2_accuracy = pruning_model_2.evaluate(test_images, test_labels)

print('base test accuracy:', base_test) 
print('Pruned test accuracy:', pruning_model_2_accuracy)

base test accuracy: 0.9873999953269958
Pruned test accuracy: 0.9855999946594238


In [139]:
# Rebinds model_for_export to the stripped version of the second pruning
# model; the save cell below writes whichever model this name currently holds.
model_for_export = tfmot.sparsity.keras.strip_pruning(pruning_model_2)

In [140]:
# mkstemp returns an open descriptor plus the path; close the descriptor right
# away so it is not leaked (save_model reopens the file by path).
pruned_model_2_fd, pruned_model_2_keras_file = tempfile.mkstemp('.h5')
os.close(pruned_model_2_fd)
tf.keras.models.save_model(model_for_export, pruned_model_2_keras_file)
print('Saved pruned Keras model to:', pruned_model_2_keras_file)

Saved pruned Keras model to: /tmp/tmpcyh8g_b6.h5


In [141]:
# Final comparison: zipped size of the baseline and of both pruned models.
print("Size of base model: %.2f bytes" % (zip_model(keras_file)))
print("Size of pruned model_1: %.2f bytes" % (zip_model(pruned_model_1_keras_file)))
print("Size of pruned model_2: %.2f bytes" % (zip_model(pruned_model_2_keras_file)))


Size of base model: 147786.00 bytes
Size of pruned model_1: 17432.00 bytes
Size of pruned model_2: 17361.00 bytes
