In [1]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tempfile
import os
import numpy as np


In [2]:
# Load the MNIST digit images and their class labels as train/test splits.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale raw pixel intensities (presumably 0-255) into [0, 1] for both splits at once.
train_images, test_images = train_images / 255, test_images / 255

# Baseline classifier: one 3x3 convolution (12 filters) and a 2x2 max-pool,
# flattened into a 10-unit Dense layer. The Dense layer has no activation, so the
# model emits raw logits -- hence from_logits=True in the loss below.
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28, 28)))
# Conv2D needs an explicit channel axis, so give each 28x28 image one channel.
model.add(keras.layers.Reshape(target_shape=(28, 28, 1)))
model.add(keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10))

# Sparse cross-entropy: labels are integer class ids, not one-hot vectors.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train for 4 epochs, holding out 10% of the training data for validation.
# The returned History object is the cell's displayed output.
model.fit(train_images, train_labels, epochs=4, validation_split=0.1)


Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x26b1f0b4250>

In [3]:
# Score the trained baseline on the test split. Because the model was compiled with
# metrics=['accuracy'], evaluate() yields (loss, accuracy); the loss goes to `_`.
_, baseline_model_accuracy = model.evaluate(
    x=test_images,
    y=test_labels,
    verbose=0,
)
print(baseline_model_accuracy)

0.9771999716758728


In [6]:
import tensorflow_model_optimization as tfmot

# Wrap the trained baseline's layers with tfmot pruning wrappers. No
# pruning_schedule argument is passed, so tfmot's default schedule applies.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model)

# Compile the pruning wrapper itself, not the baseline `model`: it is a separate
# Keras model and must be compiled before it can be fit or evaluated. Recompiling
# `model` here would leave the wrapper uncompiled and needlessly replace the
# baseline's optimizer.
model_for_pruning.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

model_for_pruning.summary()


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshape  (None, 28, 28, 1)        1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_conv2d   (None, 26, 26, 12)       230       
 (PruneLowMagnitude)                                             
                                                                 
 prune_low_magnitude_max_poo  (None, 13, 13, 12)       1         
 ling2d (PruneLowMagnitude)                                      
                                                                 
 prune_low_magnitude_flatten  (None, 2028)             1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_dense (  (None, 10)               4

In [14]:
# Fresh temporary directory for pruning logs. mkdtemp() never deletes it, so it
# persists after the notebook finishes.
logdir = tempfile.mkdtemp()

# Callbacks for training the pruned model: UpdatePruningStep advances the pruning
# wrappers' step counter during fit (required by tfmot when training a pruned
# model), and PruningSummaries writes pruning statistics under `logdir`.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]


In [18]:
# Low-learning-rate SGD with Nesterov momentum, presumably chosen so fine-tuning
# only nudges the already-trained weights while pruning proceeds.
opt = tf.keras.optimizers.SGD(learning_rate=0.0001, momentum=0.9, nesterov=True)

# Loss: labels are integer class ids (the baseline uses a sparse loss) and the final
# Dense(10) layer emits raw logits. So the loss must be sparse AND from_logits=True.
# Plain 'categorical_crossentropy' expects one-hot targets and probabilities, which
# do not match these labels or outputs.
# Target: compile the pruning wrapper rather than the baseline `model`, because the
# UpdatePruningStep / PruningSummaries callbacks operate on pruning-wrapped layers.
# metrics=['accuracy'] matches the other compile calls, so evaluate() returns
# (loss, accuracy) here too.
model_for_pruning.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy'],
)