In [90]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras import layers
from tqdm import tqdm
import matplotlib.pyplot as plt


#tf.compat.v1.enable_eager_execution()
#tf.keras.backend.clear_session()  # For easy reset of notebook state.

In [111]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

# Weights

In [112]:
weights = {
    # 5x5 conv, 1 input, 6 outputs
    'weights_conv_1': tf.Variable(tf.random.normal([5, 5, 1, 6])),
    # 5x5 conv, 6 inputs, 16 outputs
    'weights_conv_2': tf.Variable(tf.random.normal([5, 5, 6, 16])),
    #5x5 conv as in paper, 16 inputs, 120 outputs
    'weights_conv_3': tf.Variable(tf.random.normal([1, 1, 16, 120])),
    # fully connected, 5*5*16 inputs, 120 outputs
    'weights_dense_1': tf.Variable(tf.random.normal([5*5*16, 120])),
    # fully connected, 120 inputs, 84 outputs
    'weights_dense_2': tf.Variable(tf.random.normal([120, 84])),
    # 84 inputs, 10 outputs (class prediction)
    'weights_dense_3': tf.Variable(tf.random.normal([84, 10])),
}

masks = {
    # 5x5 conv, 1 input, 6 outputs
    'mask_conv_1': tf.Variable(tf.ones([5, 5, 1, 6]), trainable=False),
    # 5x5 conv, 6 inputs, 16 outputs
    'mask_conv_2': tf.Variable(tf.ones([5, 5, 6, 16]), trainable=False),
    #5x5 conv as in paper, 16 inputs, 120 outputs
    'mask_conv_3': tf.Variable(tf.ones([1, 1, 16, 120]), trainable=False),
    # fully connected, 5*5*16 inputs, 120 outputs
    'mask_dense_1': tf.Variable(tf.ones([5*5*16, 120]), trainable=False),
    # fully connected, 120 inputs, 84 outputs
    'mask_dense_2': tf.Variable(tf.ones([120, 84]), trainable=False),
    # 84 inputs, 10 outputs (class prediction)
    'mask_dense_3': tf.Variable(tf.ones([84, 10]), trainable=False),
}

biases = {
    #output depth
    'bias_conv_1': tf.Variable(tf.random.normal([6])),
    'bias_conv_2': tf.Variable(tf.random.normal([16])),
    'bias_dense_1': tf.Variable(tf.random.normal([120])),
    'bias_dense_2': tf.Variable(tf.random.normal([84])),
    'bias_dense_3': tf.Variable(tf.random.normal([10])),
}



# Wrappers

In [113]:
#conv2D with bias and relu activation

class CustomConvLayer(layers.Layer):

    def __init__(self, weights, mask, biases, strides, padding='SAME'):
        
        super(CustomConvLayer, self).__init__()
        self.w = weights
        self.m = mask
        self.b = biases
        self.s = strides
        self.p = padding

        
    def call(self, inputs):
        print('inputs',inputs)
        #print('weights', self.w)
        #print('masks', self.m)
        print('weights * masks',tf.multiply(self.w, self.m))
        x = tf.nn.conv2d(inputs, tf.multiply(self.w, self.m), strides=[1, self.s, self.s, 1], padding=self.p,)# data_format='NCHW')
        print('x', x)
        #print('bias', self.b)
        x = tf.nn.bias_add(x, self.b,)# 'NC...')
        #print('x', x)
        return tf.nn.tanh(x)
        

#Average Pooling Layer
class CustomPoolLayer(layers.Layer):
    
    def __init__(self, k=2, padding='valid'):#padding='VALID'):
        super(CustomPoolLayer, self).__init__()
        self.k = k
        self.p = padding
    
    def call(self, inputs):
#        return tf.keras.layers.AveragePooling2D(pool_size=(self.k, self.k), strides=None, padding=self.p, data_format='channels_first')(inputs)
        return tf.nn.avg_pool2d(inputs, ksize=[1, self.k, self.k,1], strides=[1, self.k, self.k, 1], padding=self.p,)# data_format='NCHW')
    
#Dense Layer with Bias
class CustomDenseLayer(layers.Layer):
    
    def __init__(self, weights, mask, bias, activation = 'tanh'):
        super(CustomDenseLayer, self).__init__()
        self.w = weights
        self.b = bias
        self.a = activation
        self.m = mask
        
    def call(self, inputs):
        #print('dense w',self.w)
        #print('dense i',inputs)
        x = tf.matmul(inputs, tf.multiply(self.w, self.m))
        print('bias ',self.b)
        x = tf.nn.bias_add(x, self.b)
        if self.a == 'tanh':
            return tf.nn.tanh(x)
        if self.a == 'softmax':
            return tf.nn.softmax(x)


# Create Model

In [114]:
class CustomConvModel(tf.keras.Model):
    def __init__(self):
        super(CustomConvModel, self).__init__()
        self.conv1 = CustomConvLayer(weights['weights_conv_1'], masks['mask_conv_1'], biases['bias_conv_1'], 1, 'SAME')#'VALID')
        self.maxpool1 = CustomPoolLayer(k=2, padding='SAME')
        self.conv2 = CustomConvLayer(weights['weights_conv_2'], masks['mask_conv_2'], biases['bias_conv_2'], 1, 'VALID')
        self.maxpool2 = CustomPoolLayer(k=2, padding='VALID')
        #self.conv3 = CustomConvLayer(weights['weights_conv_3'], masks['mask_conv_3'], biases['bias_dense_1'], 1, 'VALID')
        self.dense1 = CustomDenseLayer(weights['weights_dense_1'], masks['mask_dense_1'], biases['bias_dense_1'], 'tanh')
        self.dense2 = CustomDenseLayer(weights['weights_dense_2'], masks['mask_dense_2'], biases['bias_dense_2'], 'tanh')
        self.dense3 = CustomDenseLayer(weights['weights_dense_3'], masks['mask_dense_3'], biases['bias_dense_3'], 'softmax')
        
    def call(self, inputs):
        print('input shape', inputs.shape)
        x = tf.reshape(inputs, shape=[-1,28, 28, 1])
        print('after reshape',x.shape)
        x = self.conv1(x)
        print('after conv1', x.shape)
        x = self.maxpool1(x)
        print('after pool1',x.shape)
        x = self.conv2(x)
        #print('after conv2',x.shape)
        x = self.maxpool2(x)
        #print('yo',x.shape)
        #x = layers.Flatten()(x)
        print('after pool2',x.shape)
        #x = self.conv3(x)
        
        print('after conv3',x.shape)
        x = layers.Flatten()(x)
        #print('after flatten',x.shape)
        x = self.dense1(x)
        x = self.dense2(x)
        print(x.shape)
        x =  self.dense3(x)
        print(x.shape)
        return x
        

In [115]:
model = CustomConvModel()

In [116]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) ,
              metrics=['accuracy'],
              experimental_run_tf_function=False
             )

In [117]:
model.fit(x=x_train,
          y=y_train,
          batch_size=64,
          epochs=20,
          validation_data=(x_test, y_test),
         )

Epoch 1/20
input shape (None, 784)
after reshape (None, 28, 28, 1)
inputs Tensor("custom_conv_model_16/Reshape:0", shape=(None, 28, 28, 1), dtype=float32)
weights * masks Tensor("custom_conv_model_16/custom_conv_layer_47/Mul:0", shape=(5, 5, 1, 6), dtype=float32)
x Tensor("custom_conv_model_16/custom_conv_layer_47/Conv2D:0", shape=(None, 28, 28, 6), dtype=float32)
after conv1 (None, 28, 28, 6)
after pool1 (None, 14, 14, 6)
inputs Tensor("custom_conv_model_16/custom_pool_layer_32/AvgPool2D:0", shape=(None, 14, 14, 6), dtype=float32)
weights * masks Tensor("custom_conv_model_16/custom_conv_layer_48/Mul:0", shape=(5, 5, 6, 16), dtype=float32)
x Tensor("custom_conv_model_16/custom_conv_layer_48/Conv2D:0", shape=(None, 10, 10, 16), dtype=float32)
after pool2 (None, 5, 5, 16)
after conv3 (None, 5, 5, 16)
bias  <tf.Variable 'Variable:0' shape=(120,) dtype=float32>
bias  <tf.Variable 'Variable:0' shape=(84,) dtype=float32>
(None, 84)
bias  <tf.Variable 'Variable:0' shape=(10,) dtype=float32>
(

<tensorflow.python.keras.callbacks.History at 0x15469fb50>

In [118]:
all_layers = model.get_weights()
for layer in all_layers:
    print(layer.shape)
    
print(all_layers[6])

(5, 5, 1, 6)
(6,)
(5, 5, 1, 6)
(5, 5, 6, 16)
(16,)
(5, 5, 6, 16)
(400, 120)
(120,)
(400, 120)
(120, 84)
(84,)
(120, 84)
(84, 10)
(10,)
(84, 10)
[[ 1.6544837  -1.2155174  -0.01104877 ... -1.4212358  -0.69011974
   0.7736215 ]
 [ 1.304302   -0.36460733  0.52945864 ...  2.3375378   1.7216002
  -0.23331481]
 [-0.42149517 -0.97565067  1.5907264  ... -0.33470914  1.1034594
  -0.4771902 ]
 ...
 [-1.4904262  -0.85657346 -1.2877682  ...  0.26765004  0.20881368
  -0.6255907 ]
 [-0.4263792   0.06906602 -0.01037085 ... -0.3516376   0.37083548
  -0.58330184]
 [ 0.85798514 -0.31575802 -1.3953352  ... -0.66209424 -1.23648
  -0.7591743 ]]


In [119]:
len(model.get_weights())

15

In [120]:
def prune_conv_layers(pruning_ratio):
    layer_to_prune = [0, 3]
    pruned_weights = model.get_weights()
    
    for layer in layer_to_prune:
        converted_weights = convert_from_hwio_to_iohw(model.get_weights()[layer])
        converted_mask = convert_from_hwio_to_iohw(model.get_weights()[layer + 2]).numpy()
        for input_index, input_layer in enumerate(converted_weights):

            for kernel_index, kernel in enumerate(input_layer):
                dims = kernel.shape
                flat_weights = kernel.numpy().flatten()
                flat_masks = converted_mask[input_index][kernel_index].flatten()
                flat_weights_df = pd.DataFrame(flat_weights)
                flat_mask_df = pd.DataFrame(flat_masks)
                no_of_weights_to_prune = int(len(flat_weights)*pruning_ratio)
                #print(no_of_weights_to_prune)
                indices_to_delete = flat_weights_df.abs().values.argsort(0)[:no_of_weights_to_prune]
                for idx_to_delete in indices_to_delete:
                    flat_masks[idx_to_delete] = 0

                converted_mask[input_index][kernel_index] = flat_masks.reshape(dims)
        back_converted_mask = convert_from_iohw_to_hwio(converted_mask)
        pruned_weights[layer+2] = back_converted_mask
    
    return pruned_weights


In [121]:
pruning_ratios = [0.0, .5, 0.8, 0.9]
pre_pruning_weight_archive = []
post_pruning_weight_archive = []
post_fine_tune_weight_archive = []
pre_fine_tune_results = []
post_fine_tune_results = []

for pruning_ratio in tqdm(pruning_ratios):
    pre_pruning_weight_archive.append(model.get_weights())
    pruned_weights = prune_conv_layers(pruning_ratio)
    model.set_weights(pruned_weights)
    pruned_weights = prune_weights(model, pruning_ratio)
    model.set_weights(pruned_weights)
    post_pruning_weight_archive.append(model.get_weights())
    pre_fine_tune_results.append(model.evaluate(x_test, y_test, verbose=0))
    model.fit(x=x_train,
          y=y_train,
          batch_size=64,
          epochs=1,
          validation_data=(x_test, y_test),
         )
    post_fine_tune_results.append(model.evaluate(x_test, y_test, verbose=0))
    post_fine_tune_weight_archive.append(model.get_weights())

  0%|          | 0/4 [00:00<?, ?it/s]

(120, 84)
9
10080
no of weights 0
weights to prune shape (10080,)
(84, 10)
12
840
no of weights 0
weights to prune shape (840,)


 25%|██▌       | 1/4 [00:33<01:40, 33.46s/it]

(120, 84)
9
10080
no of weights 5040
weights to prune shape (10080,)
(84, 10)
12
840
no of weights 420
weights to prune shape (840,)


 50%|█████     | 2/4 [00:53<00:59, 29.53s/it]

(120, 84)
9
10080
no of weights 8064
weights to prune shape (10080,)
(84, 10)
12
840
no of weights 672
weights to prune shape (840,)


 75%|███████▌  | 3/4 [01:15<00:27, 27.08s/it]

(120, 84)
9
10080
no of weights 9072
weights to prune shape (10080,)
(84, 10)
12
840
no of weights 756
weights to prune shape (840,)


100%|██████████| 4/4 [01:33<00:00, 23.30s/it]


In [97]:
def prune_weights(model, pruning_ratio):
    weights = model.get_weights()
    weights_to_prune = model.get_weights()
    for index, weight in enumerate(weights):
        
        if (index == 9) or (index == 12) :
            print(weight.shape)
            print(index)
            flat_weights = weight.flatten()
            flat_weights_to_prune = weights_to_prune[index+2].flatten()
            #print (flat_weights_to_prune.shape, flat_weights.shape)
            flat_weights_df = pd.DataFrame(flat_weights)
            flat_weights_to_prune_df = pd.DataFrame(flat_weights_to_prune)
            no_of_weights_to_prune = int(len(flat_weights)*pruning_ratio)
            print(len(flat_weights))
            print('no of weights',no_of_weights_to_prune)
            print('weights to prune shape', flat_weights_to_prune.shape)
            indices_to_delete = flat_weights_df.abs().values.argsort(0)[:no_of_weights_to_prune]
            for idx_to_delete in indices_to_delete:
                flat_weights_to_prune[idx_to_delete] = 0
            dims = weights_to_prune[index+2].shape
            weights_reshaped = flat_weights_to_prune.reshape(dims)
            weights_to_prune[index+2] = weights_reshaped
    #print(weights_to_prune)
    return weights_to_prune

In [None]:
model.evaluate(x_test, y_test)

In [123]:

for index,res in enumerate(pruning_ratios):
    print('pruning ratio: ', 
          res, 
          'accuracy before fine tuning: ',
          pre_fine_tune_results[index][1],
          'accuracy after fine tuning: ',
          post_fine_tune_results[index][1])

pruning ratio:  0.0 accuracy before fine tuning:  0.9325000047683716 accuracy after fine tuning:  0.9365000128746033
pruning ratio:  0.5 accuracy before fine tuning:  0.7851999998092651 accuracy after fine tuning:  0.9301999807357788
pruning ratio:  0.8 accuracy before fine tuning:  0.16220000386238098 accuracy after fine tuning:  0.6840999722480774
pruning ratio:  0.9 accuracy before fine tuning:  0.20149999856948853 accuracy after fine tuning:  0.7943999767303467


In [None]:
model.get_weights()[0].shape

In [None]:
print(convert_from_hwio_to_iohw(model.get_weights()[0]).shape)

In [63]:
def convert_from_hwio_to_iohw(weights_nchw):
    return tf.transpose(weights_nchw, [2, 3, 0, 1])



def convert_from_iohw_to_hwio(weights_nhwc):
    return tf.transpose(weights_nhwc, [2, 3, 0, 1])

In [None]:
preds = map (lambda pred: np.argmax(pred), model.predict(x_test[:30]))
preds = list(preds)

In [None]:


plt.figure(figsize=(10,10))
for i in range(25):
    img = tf.reshape(x_test[i], shape=[28, 28])
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(img)

    plt.xlabel(preds[i])
plt.show()

In [None]:
def weight_mask_variable(var, scope):

    mask = tf.Variable(initial_value = tf.ones(var.shape),
                       trainable=False
                      )
    return mask

def apply_mask(x, scope=''):
    #print(x.shape)
    mask = weight_mask_variable(x, scope)
    #threshold = weight_threshold_variable(x, scope)

    x = tf.cast(x, tf.float32)
    masked_weights = tf.multiply(mask, x)
    

    return masked_weights