In [1]:
# Install the python library for pruning
!pip install -q tensorflow-model-optimization

In [2]:
import numpy as np
import tempfile
import tensorflow as tf
from tensorflow import keras
from keras import losses, activations, models
from keras.layers import Dense, InputLayer,BatchNormalization, Dropout, Conv2D, Flatten, Activation
import tensorflow_model_optimization as tfmot
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Load the dataset shipped with Keras. Despite the variable name `mnist`, this is
# `fashion_mnist`, not the handwritten-digit MNIST set.
mnist = keras.datasets.fashion_mnist
# load_data() is unpacked into an (inputs, labels) pair for the train and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [4]:
# Rescale both image sets by 1/255 so that every pixel value lies between 0 and 1.
# NOTE: re-running this cell divides again; restart the kernel before re-executing it.
x_train, x_test = x_train / 255.0, x_test / 255.0

In [5]:
def get_model():
    """Build and compile the small convolutional classifier used in this notebook.

    Architecture (input 28x28x1):
        two [Conv2D 3x3, stride 2, "same" padding, no bias -> BatchNormalization -> ReLU]
        stages with 32 and 64 filters, followed by Flatten, Dropout(0.5) and a
        10-way softmax Dense layer.

    Returns:
        A Keras model compiled with Adam (learning rate 0.001), sparse categorical
        cross-entropy loss and the 'accuracy' metric.
    """
    inputs = keras.Input(shape=(28, 28, 1))

    # Each stage halves the spatial resolution via the stride-2 convolution.
    features = inputs
    for n_filters in (32, 64):
        features = Conv2D(
            n_filters,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)

    # Classification head.
    features = Flatten()(features)
    features = Dropout(0.5)(features)
    outputs = Dense(10, activation="softmax")(features)

    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=losses.sparse_categorical_crossentropy,
        metrics=['accuracy'],
    )
    return model


In [6]:
# Build the compiled baseline model and print its layer-by-layer summary.
model_given = get_model()
model_given.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 32)        288       
                                                                 
 batch_normalization (BatchN  (None, 14, 14, 32)       128       
 ormalization)                                                   
                                                                 
 activation (Activation)     (None, 14, 14, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 64)          18432     
                                                                 
 batch_normalization_1 (Batc  (None, 7, 7, 64)         256       
 hNormalization)                                             

In [9]:
def weight_summary(model, x_train, y_train, batch_size=32, epochs=1):
    """Train ``model`` and report, per layer, the fraction of weights that are exactly zero.

    Note that this function trains ``model`` in place as a side effect (with a 10%
    validation split) before the weights are inspected.

    Args:
        model: Compiled Keras model to train and inspect.
        x_train: Training inputs passed to ``model.fit``.
        y_train: Training labels passed to ``model.fit``.
        batch_size: Mini-batch size used for training (default 32).
        epochs: Number of training epochs (default 1).

    Returns:
        list[str]: One entry per weight-holding layer, formatted as
        ``'sparsity=<ratio to 4 decimals> of layer=<layer name>'``. Layers without
        weights are omitted; nested models are expanded into their own layers.
    """

    def is_hier_layer(layer):
        """Return True when ``layer`` is itself a model rather than a single layer."""
        # isinstance (not an exact type() comparison) so that subclasses of
        # keras.Model / Sequential nested inside the model are recursed into too.
        return isinstance(layer, (models.Sequential, keras.Model))

    sparsity_info = []

    def layer_weights(container):
        """Append a sparsity line for every weight-holding layer in ``container``."""
        for layer in container.layers:
            if is_hier_layer(layer):
                layer_weights(layer)
                continue
            weights = layer.get_weights()
            if not weights:
                continue
            total_weights = sum(int(w.size) for w in weights)
            if total_weights == 0:
                # Only zero-sized arrays: a ratio is undefined, so skip the layer.
                continue
            non_zero_weights = sum(int(np.count_nonzero(w)) for w in weights)
            sparsity = 1.0 - (non_zero_weights / total_weights)
            sparsity_info.append(f'sparsity={sparsity:.4f} of layer={layer.name}')

    # Train first so that the inspected weights are the trained ones.
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

    layer_weights(model)
    return sparsity_info

In [10]:
# Train the baseline model (a side effect of weight_summary) and print one
# sparsity line per weight-holding layer.
x = weight_summary(model_given,x_train,y_train)
for i in x:
    print(i)

sparsity=0.0000 of layer=conv2d
sparsity=0.0000 of layer=batch_normalization
sparsity=0.0000 of layer=conv2d_1
sparsity=0.0000 of layer=batch_normalization_1
sparsity=0.0000 of layer=dense


In [11]:
def prune_weight_summary(model_given, x_train, y_train, batch_size=32, epochs=1):
    """Wrap ``model_given`` for magnitude pruning, fine-tune it, and report per-layer sparsity.

    The model is wrapped with ``tfmot.sparsity.keras.prune_low_magnitude`` using a
    ``PolynomialDecay`` schedule from 40% to 60% sparsity, recompiled, and trained
    on the whole training set (no validation split).

    Args:
        model_given: Keras model to wrap for pruning.
        x_train: Training inputs; ``x_train.shape[0]`` is used as the sample count.
        y_train: Training labels.
        batch_size: Mini-batch size used for fine-tuning (default 32).
        epochs: Number of fine-tuning epochs (default 1).

    Returns:
        tuple: ``(prune_sparsity_info, model_for_pruning)`` where the first item is
        a list of ``'sparsity=<ratio to 4 decimals> of layer=<layer name>'`` strings
        (one per weight-holding layer) and the second is the trained, still
        pruning-wrapped model.
    """
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    def is_hier_layer(layer):
        """Return True when ``layer`` is itself a model rather than a single layer."""
        # isinstance (not an exact type() comparison) so that subclasses of
        # keras.Model / Sequential nested inside the model are recursed into too.
        return isinstance(layer, (models.Sequential, keras.Model))

    # The schedule must end on the last optimizer step of THIS training run, so
    # end_step is derived from the same batch_size / epochs / sample count that are
    # passed to fit() below (fit() uses all samples: no validation split).
    num_images = x_train.shape[0]
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

    # Define model for pruning.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.40,
                                                                final_sparsity=0.60,
                                                                begin_step=0,
                                                                end_step=end_step)
    }
    model_for_pruning = prune_low_magnitude(model_given, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    model_for_pruning.compile(optimizer='adam',
                              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                              metrics=['accuracy'])
    model_for_pruning.summary()

    # UpdatePruningStep advances the pruning schedule during training;
    # PruningSummaries writes pruning logs to a temporary directory.
    logdir = tempfile.mkdtemp()
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
    ]
    model_for_pruning.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                          callbacks=callbacks)

    prune_sparsity_info = []

    def prune_layer_weights(container):
        """Append a sparsity line for every weight-holding layer in ``container``."""
        for layer in container.layers:
            if is_hier_layer(layer):
                prune_layer_weights(layer)
                continue
            # NOTE(review): for pruning wrappers get_weights() presumably also returns
            # the wrapper's own bookkeeping variables (the summary shows more params
            # per wrapped layer than the original layer had), so they are counted in
            # this ratio as well -- confirm whether that is intended.
            weights = layer.get_weights()
            if not weights:
                continue
            total_weights = sum(int(w.size) for w in weights)
            if total_weights == 0:
                # Only zero-sized arrays: a ratio is undefined, so skip the layer.
                continue
            non_zero_weights = sum(int(np.count_nonzero(w)) for w in weights)
            sparsity = 1.0 - (non_zero_weights / total_weights)
            prune_sparsity_info.append(f'sparsity={sparsity:.4f} of layer={layer.name}')

    prune_layer_weights(model_for_pruning)
    return prune_sparsity_info,model_for_pruning

In [12]:
# Wrap the baseline model for magnitude pruning and fine-tune it; keep both the
# per-layer sparsity strings and the pruning-wrapped model for later cells.
pruned_data, model_for_pruning = prune_weight_summary(model_given,x_train,y_train)
print(pruned_data)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 prune_low_magnitude_conv2d   (None, 14, 14, 32)       578       
 (PruneLowMagnitude)                                             
                                                                 
 prune_low_magnitude_batch_n  (None, 14, 14, 32)       129       
 ormalization (PruneLowMagni                                     
 tude)                                                           
                                                                 
 prune_low_magnitude_activat  (None, 14, 14, 32)       1  

In [13]:
def non_sparsity_probability(model_for_pruning, pruned_data):
    """Compute the fraction of non-zero weights for every layer that is not fully sparse.

    Args:
        model_for_pruning: Model whose top-level ``layers`` are inspected; its
            ``name`` is used as a prefix for the result keys.
        pruned_data: Iterable of strings formatted as
            ``'sparsity=<float> of layer=<layer name>'``.

    Returns:
        dict[str, float]: Maps ``'<model name>_<layer name>'`` to
        ``non_zero_weights / total_weights`` for each listed layer whose reported
        sparsity is not 1.0 and which holds weights. Only top-level layers of the
        model are matched by name.
    """
    probability_pruned_values = {}
    prefix_name = model_for_pruning.name + "_"

    for line in pruned_data:
        # Expected tokens: ['sparsity=<value>', 'of', 'layer=<name>'].
        parts = line.split()
        sparsity = float(parts[0].split('=')[1])
        # maxsplit=1 keeps layer names that themselves contain '=' intact.
        layer_intended = parts[-1].split('=', 1)[1]
        if sparsity == 1.0:
            # Fully sparse layers carry no non-zero weights to report.
            continue

        for layer in model_for_pruning.layers:
            if layer.name != layer_intended:
                continue
            weights = layer.get_weights()
            if not weights:
                continue
            total_weights = sum(int(w.size) for w in weights)
            if total_weights == 0:
                # Only zero-sized arrays: a ratio is undefined, so skip the layer.
                continue
            non_zero_weights = sum(int(np.count_nonzero(w)) for w in weights)
            probability_pruned_values[prefix_name + layer.name] = (
                non_zero_weights / total_weights
            )
    return probability_pruned_values


In [16]:
# A hyphen is not valid inside a Python identifier (it parses as subtraction),
# so the helper is called by its underscore-separated name.
non_sparsity_probability(model_for_pruning, pruned_data)


{'model_prune_low_magnitude_conv2d': 0.3993055555555556,
 'model_prune_low_magnitude_batch_normalization': 1.0,
 'model_prune_low_magnitude_conv2d_1': 0.4000108506944444,
 'model_prune_low_magnitude_batch_normalization_1': 1.0,
 'model_prune_low_magnitude_dense': 0.4002231431303793}