# Pruning the least important filters of LeNet trained on MNIST

In [6]:
from __future__ import print_function

import numpy as np
import keras

from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Flatten, Activation, Input
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

from keras import layers
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten


# Hyper-parameters. `batch_size` and `epochs` are not used in this cell;
# they are kept because other cells may rely on them.
batch_size = 128
num_classes = 10
epochs = 5

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add an explicit single-channel axis in the position the backend expects.
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Convert to float32 and divide by 255.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)


# Two conv/max-pool stages followed by two dense layers. The layer names
# 'conv_1' and 'conv_2' are looked up again by the pruning cell.
# NOTE(review): the first layer hard-codes a channels_last input shape and
# ignores the `input_shape` computed above, so this presumably only works
# when the backend uses 'channels_last' -- confirm before changing it, since
# the pretrained weights and the Permute layer may depend on that layout.
model = Sequential()
model.add(Conv2D(20,
                 [3, 3],
                 input_shape=[28, 28, 1],
                 activation='relu',
                 name='conv_1'))
model.add(MaxPool2D())
model.add(Conv2D(50, [3, 3], activation='relu', name='conv_2'))
model.add(MaxPool2D(name='maxPool_2'))
model.add(layers.Permute((2, 1, 3)))
model.add(Flatten())
model.add(Dense(500, activation='relu', name='dense_1'))
model.add(Dense(10, activation='softmax', name='preds'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])

# Weights come from a pre-trained checkpoint; nothing is trained here.
model.load_weights('model_mnist_weights.h5')

# Evaluate the test set once and reuse the result: the original ran the
# same evaluation twice to obtain `score` and `original_loss`.
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# Independent copy so `original_loss` stays a separate list, as before.
original_loss = list(score)
print('original model loss:', original_loss, '\n')
model.summary()

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Test loss: 0.030818569360674293
Test accuracy: 0.9933
original model loss: [0.030818569360674293, 0.9933] 

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_1 (Conv2D)              (None, 26, 26, 20)        200       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 13, 13, 20)        0         
_________________________________________________________________
conv_2 (Conv2D)              (None, 11, 11, 50)        9050      
_________________________________________________________________
maxPool_2 (MaxPooling2D)     (None, 5, 5, 50)          0         
_________________________________________________________________
permute_2 (Permute)          (None, 5, 5, 50)          0         
_________________________________________________________________
flatten_2 (Flatten)        

In [7]:
# Per-filter ranking tables, keyed by filter index and listed from the most
# responsible filter to the least responsible one.
# NOTE(review): each value is a (score, 1) pair; the score is presumably the
# accuracy change measured for that filter and the meaning of the constant 1
# is not visible here -- confirm against the code that produced the tables.

# Filters of conv layer 1, most responsible first.
filters_conv1 = {
    18: (0.0040999999999999925, 1),
    6: (0.0029000000000000137, 1),
    1: (0.0014999999999999458, 1),
    17: (0.0013999999999999568, 1),
    4: (0.0012999999999999678, 1),
    15: (0.0010999999999999899, 1),
    12: (0.0009000000000000119, 1),
    8: (0.0007999999999999119, 1),
    13: (0.0007999999999999119, 1),
    0: (0.0006999999999999229, 1),
    7: (0.0006999999999999229, 1),
    9: (0.0006999999999999229, 1),
    16: (0.0006999999999999229, 1),
    10: (0.0005999999999999339, 1),
    11: (0.0005999999999999339, 1),
    5: (0.00039999999999995595, 1),
    3: (0.00029999999999996696, 1),
    14: (0.00019999999999997797, 1),
    19: (0.00019999999999997797, 1),
    2: (9.999999999998899e-05, 1),
}

# Filters of conv layer 2, most responsible first.
filters_conv2 = {
    11: (0.0013999999999999568, 1),
    13: (0.0013999999999999568, 1),
    40: (0.0009000000000000119, 1),
    29: (0.0007999999999999119, 1),
    46: (0.0006999999999999229, 1),
    1: (0.0005999999999999339, 1),
    10: (0.0005999999999999339, 1),
    35: (0.0005999999999999339, 1),
    2: (0.0004999999999999449, 1),
    33: (0.0004999999999999449, 1),
    8: (0.00039999999999995595, 1),
    18: (0.00039999999999995595, 1),
    20: (0.00039999999999995595, 1),
    34: (0.00039999999999995595, 1),
    45: (0.00039999999999995595, 1),
    12: (0.00029999999999996696, 1),
    21: (0.00029999999999996696, 1),
    27: (0.00029999999999996696, 1),
    31: (0.00029999999999996696, 1),
    37: (0.00029999999999996696, 1),
    42: (0.00029999999999996696, 1),
    43: (0.00029999999999996696, 1),
    48: (0.00029999999999996696, 1),
    3: (0.00019999999999997797, 1),
    4: (0.00019999999999997797, 1),
    14: (0.00019999999999997797, 1),
    17: (0.00019999999999997797, 1),
    23: (0.00019999999999997797, 1),
    24: (0.00019999999999997797, 1),
    30: (0.00019999999999997797, 1),
    32: (0.00019999999999997797, 1),
    36: (0.00019999999999997797, 1),
    41: (0.00019999999999997797, 1),
    5: (9.999999999998899e-05, 1),
    7: (9.999999999998899e-05, 1),
    19: (9.999999999998899e-05, 1),
    22: (9.999999999998899e-05, 1),
    26: (9.999999999998899e-05, 1),
    38: (9.999999999998899e-05, 1),
    39: (9.999999999998899e-05, 1),
    44: (9.999999999998899e-05, 1),
    47: (9.999999999998899e-05, 1),
    49: (9.999999999998899e-05, 1),
    6: (0.0, 1),
    9: (0.0, 1),
    15: (0.0, 1),
    25: (0.0, 1),
    28: (0.0, 1),
    0: (-9.999999999998899e-05, 1),
    16: (-9.999999999998899e-05, 1),
}

In [8]:
from kerassurgeon import identify
from kerassurgeon.operations import delete_channels
from kerassurgeon import utils
from kerassurgeon import Surgeon

# Prune both convolution layers with keras-surgeon: keep only the listed
# "most responsible" filters, delete every other channel, then re-evaluate.
layer1 = model.get_layer('conv_1')
layer2 = model.get_layer('conv_2')

# First weight array of each layer; the number of filters is read from its
# last axis so the loops below do not rely on hard-coded layer widths.
filters1 = layer1.get_weights()[0]
filters2 = layer2.get_weights()[0]
num_filters_Layer1 = len(filters1[0, 0, 0, :])
num_filters_Layer2 = len(filters2[0, 0, 0, :])

surgeon = Surgeon(model)

### You can choose the number of filters to prune at each layer.
## The filters listed below are the most responsible ones and are kept;
## all the others are deleted.
channels_Resp_L1 = [18, 6, 1, 17, 4, 15, 12]
channelsToDelete_L1 = [i for i in range(num_filters_Layer1)
                       if i not in channels_Resp_L1]
surgeon.add_job('delete_channels', layer1,
                channels=channelsToDelete_L1)

# NOTE(review): filter 46 is ranked 5th in `filters_conv2` but is absent
# from this keep-list -- confirm whether that omission is intentional.
channels_Resp_L2 = [11, 13, 40, 29, 1, 10, 35, 2, 33, 8, 18, 20, 34, 45, 12, 21, 27, 31, 37, 42, 43, 48]
channelsToDelete_L2 = [i for i in range(num_filters_Layer2)
                       if i not in channels_Resp_L2]
surgeon.add_job('delete_channels', layer2,
                channels=channelsToDelete_L2)

# Apply both queued jobs at once; `model` is rebound to the pruned network.
model = surgeon.operate()
# Compile the pruned model before evaluating it.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
loss_after = model.evaluate(x_test, y_test, verbose=0)
print('after model loss:', loss_after, '\n')
model.summary()

Deleting 13/20 channels from layer: conv_1
Deleting 28/50 channels from layer: conv_2
after model loss: [0.21697785779237747, 0.9414] 

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_1_input (InputLayer)    (None, 28, 28, 1)         0         
_________________________________________________________________
conv_1 (Conv2D)              (None, 26, 26, 7)         70        
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 multiple                  0         
_________________________________________________________________
conv_2 (Conv2D)              (None, 11, 11, 22)        1408      
_________________________________________________________________
maxPool_2 (MaxPooling2D)     multiple                  0         
_________________________________________________________________
permute_2 (Permute)          multiple                  0         
______