# Pruning a CNN for MNIST using TensorBoard callbacks

## Prepare data

In [1]:
import os
import tempfile
import zipfile
from datetime import datetime as date

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

In [2]:
# Reset the global state Keras keeps between model builds (including the counters
# behind auto-generated layer/model names), so re-running the notebook in the
# same kernel starts from a clean slate.
tf.keras.backend.clear_session()

In [3]:
# obtain dataset and display size
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print((x_train.shape, x_test.shape))

# normalize images to [0, 1]; cast to float32 first so the scaled arrays are not
# promoted to float64 (twice the memory, while Keras' default float type is float32)
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# one-hot encoding (MNIST has ten digit classes)
NUM_CLASSES = 10
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

((60000, 28, 28), (10000, 28, 28))


## Set up CNNs for pruning with various sparsities

In [4]:
def _build_cnn():
    """Build and compile one fresh, unpruned CNN for 28x28 single-channel images."""
    layers = tf.keras.layers
    model = tf.keras.models.Sequential([
        layers.InputLayer(input_shape=(28, 28)),
        layers.Reshape(target_shape=(28, 28, 1)),

        # two convolution + max-pooling stages
        layers.Conv2D(filters=28, activation='relu', kernel_size=(3,3)),
        layers.MaxPooling2D((2,2)),
        layers.Conv2D(filters=28*2, activation='relu', kernel_size=(3,3)),
        layers.MaxPooling2D((2,2)),

        # dense classifier head
        layers.Flatten(),
        layers.Dense(300, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


# one pruning-wrapped model per target sparsity 0.0, 0.1, ..., 0.9
# (every layer of the network is wrapped; only the sparsity differs)
sparse_models = []

for s in range(10):
    sparsity = s/10

    # constant target sparsity, active from training step 0
    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=sparsity, begin_step=0
    )

    cnn = _build_cnn()

    # wrap the freshly built network with the pruning schedule
    sparse_model = tfmot.sparsity.keras.prune_low_magnitude(cnn, pruning_schedule=pruning_schedule)
    sparse_models.append(sparse_model)

# show the structure of the last (highest-sparsity) model
sparse_model.summary()



Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape_ (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 26, 26, 28)        534       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 28)        1         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 11, 11, 56)        28282     
_________________________________________________________________
prune_low_magnitude_max_pool (None, 5, 5, 56)          1         
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 1400)              1         
_________________________________________________________________
prune_low_magnitude_dense_18 (None, 300)              

## Train

In [None]:
epochs = 20
histories = []

# Root directory for the TensorBoard logs. Set the PRUNING_LOG_DIR environment
# variable to redirect them; the default is relative to the working directory
# rather than a machine-specific absolute path.
LOG_ROOT = os.environ.get("PRUNING_LOG_DIR", "log")

# train each network
# NOTE(review): only indices 7-9 (sparsity 0.7-0.9) are trained here, while the
# "Compare test accuracy" section evaluates every entry of `sparse_models` —
# confirm whether the full range should be trained on a fresh run.
for s in range(7,10):
    net = sparse_models[s]
    sparsity = s/10

    # log directory; a single clock reading is used for both the date and the
    # time component so they cannot disagree
    started = date.now()
    log_dir = (
        LOG_ROOT + "/"
        + started.strftime("%Y-%m-%d/sparse-nets/sparsity=") + str(sparsity)
        + started.strftime("-%H-%M-%S")
    )

    # UpdatePruningStep must be present when fitting a pruning-wrapped model;
    # PruningSummaries writes the pruning data for TensorBoard into log_dir
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir, update_freq=1)
    ]

    # the wrapped model has to be compiled before it can be fitted
    net.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])

    # record histories
    histories.append(net.fit(x_train, y_train, callbacks=callbacks, epochs=epochs))

    # evaluate() returns [loss, accuracy]; index 1 is the test accuracy
    print("Sparsity=" + str(sparsity) + ", Accuracy=" + str(net.evaluate(x_test, y_test)[1]))

## Compare test accuracy

In [None]:
# Collect test accuracy and test loss for every model, indexed like `sparse_models`.
# NOTE(review): models that were not compiled/trained in the training section
# cannot be evaluated meaningfully — confirm all entries were fitted first.
acc = []
loss = []
for s, model in enumerate(sparse_models):
    # evaluate() returns the loss first, followed by the compiled metrics
    # (here: accuracy) — unpack in that order
    test_loss, test_acc = model.evaluate(x_test, y_test)
    print("Sparsity="+str(s/10.0)+", Test Acc.="+str(test_acc)+", Test Loss="+str(test_loss))
    acc.append(test_acc)
    loss.append(test_loss)

## Compare sizes

In [34]:
# the quantity of interest is the on-disk size (in bytes) of the weights file
cnn_weights_file = "cnn_weights.h5"
cnn.save_weights(cnn_weights_file)
os.path.getsize(cnn_weights_file)

2875408

In [7]:
# same measurement for `net`, the model most recently bound by the training loop,
# saved with its pruning wrappers still attached
sparse_weights_file = "sparsity_90_weights.h5"
net.save_weights(sparse_weights_file)
os.path.getsize(sparse_weights_file)

5742916

In [35]:
# Save the weights of `cnn_for_pruning` and report the file size in bytes.
# NOTE(review): `cnn_for_pruning` is not defined in any visible cell, so this cell
# fails on a fresh kernel. Presumably it was a pruning-wrapped copy of the CNN
# created in a since-deleted cell — confirm, then restore its definition or drop this cell.
cnn_for_pruning.save_weights("cnn_for_pruning_weights.h5")
os.path.getsize("cnn_for_pruning_weights.h5")

5742908

In [20]:
import tempfile

def get_gzipped_model_size(model):
    """Return the size in bytes of `model` after saving and DEFLATE-compressing it.

    The model is written to a temporary ``.h5`` file via
    ``model.save(path, include_optimizer=False)``, that file is added to a
    temporary zip archive with ``ZIP_DEFLATED`` compression, and the size of
    the archive is returned. Both temporary files are removed before returning.

    Args:
        model: object exposing ``save(filepath, include_optimizer=...)``,
            e.g. a Keras model.

    Returns:
        int: size of the compressed archive in bytes.
    """
    # mkstemp returns an already-open OS-level descriptor; close it right away
    # so it is not leaked (and so the path can be reopened on every platform)
    keras_fd, keras_file = tempfile.mkstemp('.h5')
    os.close(keras_fd)
    zip_fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(zip_fd)

    try:
        model.save(keras_file, include_optimizer=False)

        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(keras_file)

        return os.path.getsize(zipped_file)
    finally:
        # do not leave model-sized files behind in the temp directory
        for path in (keras_file, zipped_file):
            if os.path.exists(path):
                os.remove(path)

# The compression benefit of pruning only shows once the pruning wrappers are
# removed with strip_pruning, so both models are stripped before measuring.
# NOTE(review): the saved output of this cell reports byte-identical sizes for the
# "regular" and the "sparse" model. `cnn` is the last network built in the
# model-construction loop, i.e. the one that was wrapped to produce `net` — check
# whether the two share weights, in which case this is not a dense-vs-sparse comparison.
strip_cnn = tfmot.sparsity.keras.strip_pruning(cnn)
regular_weights_file = "strip_cnn_weights.h5"
strip_cnn.save_weights(regular_weights_file)
print("Regular model, stripped:", os.path.getsize(regular_weights_file))

strip_sparse = tfmot.sparsity.keras.strip_pruning(net)
sparse_weights_file = "strip_sparse_weights.h5"
strip_sparse.save_weights(sparse_weights_file)
print("Sparse model, stripped: ", os.path.getsize(sparse_weights_file))

# compressed size of each stripped model (whole model, optimizer excluded)
regular_gzipped = get_gzipped_model_size(strip_cnn)
print("Size of gzipped regular model: %.2f bytes" % (regular_gzipped))
sparse_gzipped = get_gzipped_model_size(strip_sparse)
print("Size of gzipped sparse model:  %.2f bytes" % (sparse_gzipped))

Regular model, stripped: 2875408
Sparse model, stripped:  2875408
Size of gzipped regular model: 544832.00 bytes
Size of gzipped sparse model:  544832.00 bytes
