In [19]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras import layers
from tqdm import tqdm
import matplotlib.pyplot as plt

import foolbox as fb

#tf.compat.v1.enable_eager_execution()
#tf.keras.backend.clear_session()  # For easy reset of notebook state.

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

# Weights

In [3]:
weights = {
    # 5x5 conv, 1 input, 6 outputs
    'weights_conv_1': tf.Variable(tf.random.normal([5, 5, 1, 6])),
    # 5x5 conv, 6 inputs, 16 outputs
    'weights_conv_2': tf.Variable(tf.random.normal([5, 5, 6, 16])),
    #5x5 conv as in paper, 16 inputs, 120 outputs
    'weights_conv_3': tf.Variable(tf.random.normal([1, 1, 16, 120])),
    # fully connected, 5*5*16 inputs, 120 outputs
    'weights_dense_1': tf.Variable(tf.random.normal([5*5*16, 120])),
    # fully connected, 120 inputs, 84 outputs
    'weights_dense_2': tf.Variable(tf.random.normal([120, 84])),
    # 84 inputs, 10 outputs (class prediction)
    'weights_dense_3': tf.Variable(tf.random.normal([84, 10])),
}

masks = {
    # 5x5 conv, 1 input, 6 outputs
    'mask_conv_1': tf.Variable(tf.ones([5, 5, 1, 6]), trainable=False),
    # 5x5 conv, 6 inputs, 16 outputs
    'mask_conv_2': tf.Variable(tf.ones([5, 5, 6, 16]), trainable=False),
    #5x5 conv as in paper, 16 inputs, 120 outputs
    'mask_conv_3': tf.Variable(tf.ones([1, 1, 16, 120]), trainable=False),
    # fully connected, 5*5*16 inputs, 120 outputs
    'mask_dense_1': tf.Variable(tf.ones([5*5*16, 120]), trainable=False),
    # fully connected, 120 inputs, 84 outputs
    'mask_dense_2': tf.Variable(tf.ones([120, 84]), trainable=False),
    # 84 inputs, 10 outputs (class prediction)
    'mask_dense_3': tf.Variable(tf.ones([84, 10]), trainable=False),
}

biases = {
    #output depth
    'bias_conv_1': tf.Variable(tf.random.normal([6])),
    'bias_conv_2': tf.Variable(tf.random.normal([16])),
    'bias_dense_1': tf.Variable(tf.random.normal([120])),
    'bias_dense_2': tf.Variable(tf.random.normal([84])),
    'bias_dense_3': tf.Variable(tf.random.normal([10])),
}



# Wrappers

In [4]:
#conv2D with bias and relu activation

class CustomConvLayer(layers.Layer):

    def __init__(self, weights, mask, biases, strides, padding='SAME'):
        
        super(CustomConvLayer, self).__init__()
        self.w = weights
        self.m = mask
        self.b = biases
        self.s = strides
        self.p = padding

        
    def call(self, inputs):
        #print('inputs',inputs)
        #print('weights', self.w)
        #print('masks', self.m)
        #print('weights * masks',tf.multiply(self.w, self.m))
        x = tf.nn.conv2d(inputs, tf.multiply(self.w, self.m), strides=[1, self.s, self.s, 1], padding=self.p,)# data_format='NCHW')
        #print('x', x)
        #print('bias', self.b)
        x = tf.nn.bias_add(x, self.b,)# 'NC...')
        #print('x', x)
        return tf.nn.tanh(x)
        

#Average Pooling Layer
class CustomPoolLayer(layers.Layer):
    
    def __init__(self, k=2, padding='valid'):#padding='VALID'):
        super(CustomPoolLayer, self).__init__()
        self.k = k
        self.p = padding
    
    def call(self, inputs):
#        return tf.keras.layers.AveragePooling2D(pool_size=(self.k, self.k), strides=None, padding=self.p, data_format='channels_first')(inputs)
        return tf.nn.avg_pool2d(inputs, ksize=[1, self.k, self.k,1], strides=[1, self.k, self.k, 1], padding=self.p,)# data_format='NCHW')
    
#Dense Layer with Bias
class CustomDenseLayer(layers.Layer):
    
    def __init__(self, weights, mask, bias, activation = 'tanh'):
        super(CustomDenseLayer, self).__init__()
        self.w = weights
        self.b = bias
        self.a = activation
        self.m = mask
        
    def call(self, inputs):
        #print('dense w',self.w)
        #print('dense i',inputs)
        x = tf.matmul(inputs, tf.multiply(self.w, self.m))
        #print('bias ',self.b)
        x = tf.nn.bias_add(x, self.b)
        if self.a == 'tanh':
            return tf.nn.tanh(x)
        if self.a == 'softmax':
            return tf.nn.softmax(x)


# Create Model

In [5]:
class CustomConvModel(tf.keras.Model):
    def __init__(self):
        super(CustomConvModel, self).__init__()
        self.conv1 = CustomConvLayer(weights['weights_conv_1'], masks['mask_conv_1'], biases['bias_conv_1'], 1, 'SAME')#'VALID')
        self.maxpool1 = CustomPoolLayer(k=2, padding='SAME')
        self.conv2 = CustomConvLayer(weights['weights_conv_2'], masks['mask_conv_2'], biases['bias_conv_2'], 1, 'VALID')
        self.maxpool2 = CustomPoolLayer(k=2, padding='VALID')
        #self.conv3 = CustomConvLayer(weights['weights_conv_3'], masks['mask_conv_3'], biases['bias_dense_1'], 1, 'VALID')
        self.dense1 = CustomDenseLayer(weights['weights_dense_1'], masks['mask_dense_1'], biases['bias_dense_1'], 'tanh')
        self.dense2 = CustomDenseLayer(weights['weights_dense_2'], masks['mask_dense_2'], biases['bias_dense_2'], 'tanh')
        self.dense3 = CustomDenseLayer(weights['weights_dense_3'], masks['mask_dense_3'], biases['bias_dense_3'], 'softmax')
        
    def call(self, inputs):
        #print('input shape', inputs.shape)
        x = tf.reshape(inputs, shape=[-1,28, 28, 1])
        #print('after reshape',x.shape)
        x = self.conv1(x)
        #print('after conv1', x.shape)
        x = self.maxpool1(x)
        #print('after pool1',x.shape)
        x = self.conv2(x)
        #print('after conv2',x.shape)
        x = self.maxpool2(x)
        #print('yo',x.shape)
        #x = layers.Flatten()(x)
        #print('after pool2',x.shape)
        #x = self.conv3(x)
        
        #print('after conv3',x.shape)
        x = layers.Flatten()(x)
        #print('after flatten',x.shape)
        x = self.dense1(x)
        x = self.dense2(x)
        #print(x.shape)
        x =  self.dense3(x)
        #print(x.shape)
        return x
        

In [31]:
model = CustomConvModel()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) ,
              metrics=['accuracy'],
              experimental_run_tf_function=False
             )

model.fit(x=x_train,
          y=y_train,
          batch_size=64,
          epochs=10,
          validation_data=(x_test, y_test),
         )
model.save('./saved-models/cnn-structural-pruning-pipeline')

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./saved-models/cnn-structural-pruning-pipeline/assets


In [29]:
def structural_prune_conv_layers(model, pruning_ratio):
    all_layers = model.get_weights()
    layers_to_prune = [0,3]
    for layer_to_prune in layers_to_prune:
        mask = structural_prune_channels(all_layers[layer_to_prune], pruning_ratio)
        all_layers[layer_to_prune+2] = mask
    model.set_weights(all_layers)
    return model

def structural_prune_channels(layer, pruning_ratio):
    vals = []
    iohw_layer = convert_from_hwio_to_iohw(layer)
    converted_shape = iohw_layer.shape
    no_of_channels = converted_shape[0]*converted_shape[1]
    no_of_channels_to_prune = int(np.round(pruning_ratio * no_of_channels))
    channels = tf.reshape(iohw_layer, (no_of_channels,converted_shape[2],converted_shape[3]))
    mask = tf.ones(channels.shape)
    for channel in channels:
        vals.append(tf.math.reduce_sum(tf.math.abs(channel)))
    channels_to_prune = np.argsort(vals)[:no_of_channels_to_prune]
    mask= mask.numpy()
    for channel_to_prune in channels_to_prune:
        mask[channel_to_prune] = tf.zeros([converted_shape[2],converted_shape[3]])

    reshaped_mask = tf.reshape(mask, converted_shape)
    hwio_mask = convert_from_iohw_to_hwio(reshaped_mask)
    return hwio_mask

In [8]:
def convert_from_hwio_to_iohw(weights_nchw):
    return tf.transpose(weights_nchw, [2, 3, 0, 1])



def convert_from_iohw_to_hwio(weights_nhwc):
    return tf.transpose(weights_nhwc, [2, 3, 0, 1])

In [9]:
def prune_conv_layers(pruning_ratio):
    layer_to_prune = [0, 3]
    pruned_weights = model.get_weights()
    
    for layer in layer_to_prune:
        converted_weights = convert_from_hwio_to_iohw(model.get_weights()[layer])
        converted_mask = convert_from_hwio_to_iohw(model.get_weights()[layer + 2]).numpy()
        for input_index, input_layer in enumerate(converted_weights):

            for kernel_index, kernel in enumerate(input_layer):
                dims = kernel.shape
                flat_weights = kernel.numpy().flatten()
                flat_masks = converted_mask[input_index][kernel_index].flatten()
                flat_weights_df = pd.DataFrame(flat_weights)
                flat_mask_df = pd.DataFrame(flat_masks)
                no_of_weights_to_prune = int(len(flat_weights)*pruning_ratio)
                #print(no_of_weights_to_prune)
                indices_to_delete = flat_weights_df.abs().values.argsort(0)[:no_of_weights_to_prune]
                for idx_to_delete in indices_to_delete:
                    flat_masks[idx_to_delete] = 0

                converted_mask[input_index][kernel_index] = flat_masks.reshape(dims)
        back_converted_mask = convert_from_iohw_to_hwio(converted_mask)
        pruned_weights[layer+2] = back_converted_mask
    
    return pruned_weights


In [33]:
pruning_ratios = [0.0, .3, .5, .7, 0.8, 0.9, .95, .97, .98, .99]
accuracies = []
pgd_success_rates = []
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
for index,pruning_ratio in tqdm(enumerate(pruning_ratios)):
    model = tf.keras.models.load_model('./saved-models/cnn-structural-pruning-pipeline')
    

    for idx in range(index+1):
        #print(pruning_ratios[idx])
        pruned_model = structural_prune_conv_layers(model, pruning_ratios[idx])
        if idx != index:
        # train x epochs before doing next pruning step
            pruned_model.fit(x=x_train,
                         y=y_train,
                         batch_size=64,
                         epochs=3,
                         callbacks=[callback],
                         validation_data=(x_test, y_test),
                        )
        if idx == index:        
        # train to convergence after final pruning
            pruned_model.fit(x=x_train,
                             y=y_train,
                             batch_size=64,
                             epochs=100,
                             callbacks=[callback],
                             validation_data=(x_test, y_test),
                            )
    pgd_success_rates.append(pgd_attack(pruned_model))
    accuracies.append(pruned_model.evaluate(x_test, y_test, verbose=0))

0it [00:00, ?it/s]

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100


1it [05:36, 336.20s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100


2it [09:20, 302.61s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100


3it [15:43, 326.80s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100


4it [21:39, 335.62s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100


5it [28:00, 349.24s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100


6it [37:35, 416.81s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100


7it [44:17, 412.45s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100


8it [52:01, 428.00s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100


9it [1:01:36, 472.09s/it]

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100


10it [1:10:42, 424.28s/it]


In [34]:

for index,ratio in enumerate(pruning_ratios):
    print('pruning ratio: ', 
          ratio, 
          'accuracy after fine tuning: ',
          accuracies[index][1])

for index,ratio in enumerate(pruning_ratios):
    print('pruning ratio: ', 
          ratio, 
          'accuracy after fine tuning: ',
          pgd_success_rates[index])

pruning ratio:  0.0 accuracy after fine tuning:  0.9781000018119812
pruning ratio:  0.3 accuracy after fine tuning:  0.9769999980926514
pruning ratio:  0.5 accuracy after fine tuning:  0.9768999814987183
pruning ratio:  0.7 accuracy after fine tuning:  0.9702000021934509
pruning ratio:  0.8 accuracy after fine tuning:  0.9433000087738037
pruning ratio:  0.9 accuracy after fine tuning:  0.9419000148773193
pruning ratio:  0.95 accuracy after fine tuning:  0.0957999974489212
pruning ratio:  0.97 accuracy after fine tuning:  0.0982000008225441
pruning ratio:  0.98 accuracy after fine tuning:  0.0957999974489212
pruning ratio:  0.99 accuracy after fine tuning:  0.0957999974489212
pruning ratio:  0.0 accuracy after fine tuning:  0.096
pruning ratio:  0.3 accuracy after fine tuning:  0.096
pruning ratio:  0.5 accuracy after fine tuning:  0.132
pruning ratio:  0.7 accuracy after fine tuning:  0.116
pruning ratio:  0.8 accuracy after fine tuning:  0.294
pruning ratio:  0.9 accuracy after fine t

In [12]:
def prune_weights(model, pruning_ratio):
    weights = model.get_weights()
    weights_to_prune = model.get_weights()
    for index, weight in enumerate(weights):
        
        if (index == 9) or (index == 12) :
            print(weight.shape)
            print(index)
            flat_weights = weight.flatten()
            flat_weights_to_prune = weights_to_prune[index+2].flatten()
            #print (flat_weights_to_prune.shape, flat_weights.shape)
            flat_weights_df = pd.DataFrame(flat_weights)
            flat_weights_to_prune_df = pd.DataFrame(flat_weights_to_prune)
            no_of_weights_to_prune = int(len(flat_weights)*pruning_ratio)
            print(len(flat_weights))
            print('no of weights',no_of_weights_to_prune)
            print('weights to prune shape', flat_weights_to_prune.shape)
            indices_to_delete = flat_weights_df.abs().values.argsort(0)[:no_of_weights_to_prune]
            for idx_to_delete in indices_to_delete:
                flat_weights_to_prune[idx_to_delete] = 0
            dims = weights_to_prune[index+2].shape
            weights_reshaped = flat_weights_to_prune.reshape(dims)
            weights_to_prune[index+2] = weights_reshaped
    #print(weights_to_prune)
    return weights_to_prune

In [13]:
model.evaluate(x_test, y_test)



[2.311707019805908, 0.11349999904632568]

In [14]:
model.get_weights()[0].shape

(5, 5, 1, 6)

In [15]:
print(convert_from_hwio_to_iohw(model.get_weights()[0]).shape)

(1, 6, 5, 5)


In [16]:
def convert_from_hwio_to_iohw(weights_nchw):
    return tf.transpose(weights_nchw, [2, 3, 0, 1])



def convert_from_iohw_to_hwio(weights_nhwc):
    return tf.transpose(weights_nhwc, [2, 3, 0, 1])

In [20]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

x = tf.convert_to_tensor(x_train[:500].reshape(500,28*28))
y = tf.convert_to_tensor([y_train[:500]])[0];

In [28]:
def pgd_attack(model_to_attack):
    fmodel = fb.models.TensorFlowModel(model_to_attack, bounds=(0,1))
    attack = fb.attacks.LinfProjectedGradientDescentAttack()
    adversarials = attack(
        fmodel,
        x,
        y,
        epsilons=[8/255]
    )
    return np.count_nonzero(adversarials[2])/500

IndexError: list assignment index out of range