# Network Pruning
inspired by https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras?hl=en#build_the_mnist_model

In [18]:
%load_ext tensorboard
import tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [19]:
import tensorflow as tf
import numpy as np
import pandas as pd

import tempfile
import zipfile
import os

print(tf.__version__)

2.0.0


# Load and Preprocess Data

In [20]:
num_classes = 10

# input image dimensions
img_rows, img_cols = 28, 28

(x_train, y_train), (x_test,y_test) = tf.keras.datasets.mnist.load_data()

print(x_train.shape)

x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

#normalize
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

x_train /= 255
x_test /= 255

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

(60000, 28, 28)
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


# Initialize and Compile Two Identical Models
The models' architecture are LeNet-like (relu instead of sigmoid and softmax instead of ???)

In [21]:
l = tf.keras.layers

model = tf.keras.Sequential([
    l.Conv2D(6, 5, activation='relu', input_shape=input_shape),
    l.MaxPooling2D((2,2,), (2,2)),
    l.Conv2D(16, 5, activation='relu'),
    l.MaxPooling2D((2,2,), (2,2)),
    l.Flatten(),
    l.Dense(120, activation='relu'),
    l.Dense(84, activation='relu'),
    l.Dense(num_classes, activation='softmax')
])

model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy']
)

model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_8 (Conv2D)            (None, 24, 24, 6)         156       
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 12, 12, 6)         0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 8, 8, 16)          2416      
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 4, 4, 16)          0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 256)               0         
_________________________________________________________________
dense_12 (Dense)             (None, 120)               30840     
_________________________________________________________________
dense_13 (Dense)             (None, 84)               

In [22]:
pruned_model = tf.keras.Sequential([
    l.Conv2D(6, 5, activation='relu', input_shape=input_shape),
    l.MaxPooling2D((2,2,), (2,2)),
    l.Conv2D(16, 5, activation='relu'),
    l.MaxPooling2D((2,2,), (2,2)),
    l.Flatten(),
    l.Dense(120, activation='relu'),
    l.Dense(84, activation='relu'),
    l.Dense(num_classes, activation='softmax')
])

pruned_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy']
)

#pruned_model.set_weights(model.get_weights())

In [23]:
logdir = tempfile.mkdtemp()
#print('Writing training logs to ' + logdir)

In [24]:
#%tensorboard --logdir={logdir}


In [25]:
batch_size = 32
epochs = 10
callbacks = [tf.keras.callbacks.TensorBoard(log_dir=logdir, profile_batch=0)]

model.fit(x_train, 
          y_train,
          batch_size=batch_size,
          epochs=epochs,
          callbacks=callbacks,
          validation_data=(x_test, y_test)
         )
score = model.evaluate(x_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


Test loss: 1.4734451370239259
Test accuracy: 0.988


In [26]:
def prune_weights(model, pruning_ratio):
    weights = model.get_weights()
    weights_to_prune = model.get_weights()
    for index, weight in enumerate(weights):
        flat_weights = weight.flatten()
        flat_weights_df = pd.DataFrame(flat_weights)
        no_of_weights_to_prune = int(len(flat_weights)*pruning_ratio)
        indices_to_delete = flat_weights_df.abs().values.argsort(0)[:no_of_weights_to_prune]
        for idx_to_delete in indices_to_delete:
            flat_weights[idx_to_delete] = 0

        weights_reshaped = flat_weights.reshape(weight.shape)
        weights_to_prune[index] = weights_reshaped
    return weights_to_prune
    

In [27]:
pruning_ratios = [x/10 for x in range(10)]
losses = []
accuracies = []


for pruning_ratio in pruning_ratios:
    pruned_weights = prune_weights(model, pruning_ratio)
    pruned_model.set_weights(pruned_weights)
    res = pruned_model.evaluate(x_test, y_test)
    losses.append(res[0])
    accuracies.append(res[1])





















In [28]:
scores = pd.DataFrame(zip(pruning_ratios,losses, accuracies), columns=['Pruning Ratio', 'loss', 'accuracy'])

print (scores)

   Pruning Ratio      loss  accuracy
0            0.0  1.473445    0.9880
1            0.1  1.473628    0.9874
2            0.2  1.473976    0.9873
3            0.3  1.476857    0.9844
4            0.4  1.482400    0.9788
5            0.5  1.496366    0.9659
6            0.6  1.548883    0.9233
7            0.7  1.649888    0.8318
8            0.8  1.975161    0.6173
9            0.9  2.280443    0.1764


In [17]:
scores = pd.DataFrame(zip(pruning_ratios,losses, accuracies), columns=['Pruning Ratio', 'loss', 'accuracy'])

print (scores)

   Pruning Ratio      loss  accuracy
0            0.0  1.479938    0.9812
1            0.1  1.480971    0.9799
2            0.2  1.484322    0.9769
3            0.3  1.486142    0.9746
4            0.4  1.492618    0.9682
5            0.5  1.511035    0.9500
6            0.6  1.511988    0.9514
7            0.7  1.626180    0.8486
8            0.8  2.133850    0.3781
9            0.9  2.289310    0.1819


In [29]:
pruned_model.get_weights()

[array([[[[ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.43808812,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]]],
 
 
        [[[ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        ,  0.        ,  0.        ,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        ,  0.        ,  0.        ,  0.        ,
           -0.5508476 ,  0.        ]],
 
         [[ 0.49325785,  0.        ,  0.        ,  0.        

In [30]:
model.get_weights()

[array([[[[ 1.79423764e-01,  2.14725375e-01, -3.39581579e-01,
            3.05215627e-01,  1.30352497e-01,  1.28028125e-01]],
 
         [[ 1.20444492e-01,  3.25849563e-01, -2.61012912e-01,
            3.89078468e-01, -1.95800751e-01, -2.79138863e-01]],
 
         [[-7.57256895e-02,  1.79318666e-01,  2.55610552e-02,
            2.10629836e-01, -2.68746883e-01, -3.33399355e-01]],
 
         [[ 2.42547885e-01,  2.22637936e-01,  2.24815428e-01,
            2.20081270e-01, -2.82113612e-01,  5.01002334e-02]],
 
         [[ 4.38088119e-01,  1.54954836e-01,  5.89292590e-03,
            3.97938713e-02,  1.86353009e-02,  2.25025732e-02]]],
 
 
        [[[-1.56244874e-01,  9.72465500e-02, -3.80078405e-01,
           -1.36664346e-01,  2.01182410e-01, -3.78449298e-02]],
 
         [[ 1.02278844e-01,  2.30645537e-01, -2.13751897e-01,
           -3.83684076e-02, -9.00138617e-02,  1.92369238e-01]],
 
         [[-1.90683797e-01,  1.24930948e-01,  2.82237697e-02,
            6.72951564e-02, -3.99327874