# Importance Critorion Pruned Deep Neural Network

The idea behind this network is to compute the relative importance of each node in a layer of a trained neural network in determining the output over the entire training dataset. This way, the parameters that contribute little (on average) to classification of the output (independently of whether they lead to minimising the cost, unlike a backpropagation algorithm) will get a low importance score. Based on those importance scores, the network can than be pruned, reducing its size. The new network can potentially lead to a better performance, due to higher generalisation of the kept parameters. 

Below code implements calculations of this criterion on a very simple model: 3 layer fully connected NN on the MNIST dataset.

In [1]:
import tensorflow as tf
import numpy as np
import time

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
def load_data():
    """Loads mnist data"""
    return input_data.read_data_sets("MNIST_data/", one_hot=True)

mnist = load_data()

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [4]:
def _weight_variable(shape, random_seed=0):
    """Create a weight variable with appropriate initialization"""
    initial = tf.truncated_normal(shape, stddev=0.1, seed=random_seed)
    return tf.Variable(initial, name='weights')


def _bias_variable(shape):
    """Create a bias variable with appropriate initialization"""
    initial = tf.constant(0.2, shape=shape)
    return tf.Variable(initial, name='biases')

In [5]:
x0 = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

w0 = _weight_variable([784, 800])
b0 = _bias_variable([800])
x1 = tf.matmul(x0, w0)
y1 = tf.nn.relu(x1 + b0)

w1 = _weight_variable([800, 800])
b1 = _bias_variable([800])
x2 = tf.matmul(x1, w1)
y2 = tf.nn.relu(x2 + b1)

w2 = _weight_variable([800, 10])
b2 = _bias_variable([10])
x3 = tf.matmul(x2, w2) + b2

global_step = tf.train.get_or_create_global_step()

In [6]:
starter_learning_rate = 1e-1
lr = tf.train.exponential_decay(starter_learning_rate, global_step,
                                           10000, 0.9, staircase=True)
# lr = 0.01
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=x3))
train_step = tf.train.AdamOptimizer(learning_rate=lr).minimize(cross_entropy, global_step=global_step)

# Accuracy metrics
correct_prediction = tf.equal(tf.argmax(x3, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

try:
    for i in range(60000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x0: batch_xs, y_:batch_ys})
        if i % 200 == 0:
            train_accuracy = accuracy.eval(feed_dict={x0: batch_xs, y_: batch_ys})
            step, current_lr = sess.run([global_step, lr])
            print("Step {}, Accuracy: {}, lr: {}".format(step, train_accuracy, current_lr), end='\r')
except KeyboardInterrupt:
    pass

Step 19801, Accuracy: 0.9100000262260437, lr: 0.08999999612569809

In [7]:
# Evaluation:

baseline_accuracy = accuracy.eval(feed_dict={x0: mnist.test.images,
                                             y_: mnist.test.labels})
print('Test accuracy {}'.format(baseline_accuracy))

Test accuracy 0.8860999941825867


In [8]:
# Save the weights and biases as numpy arrays
w0_save, b0_save, w1_save, b1_save, w2_save, b2_save, last_step = sess.run([w0, b0, w1, b1, w2, b2, global_step])

In [16]:
# Eventually load the variables from here to keep on training the model.
tf.reset_default_graph()

x0 = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

w0 = tf.Variable(w0_save)
b0 = tf.Variable(b0_save)
x1 = tf.matmul(x0, w0)
y1 = tf.nn.relu(x1 + b0)

w1 = tf.Variable(w1_save)
b1 = tf.Variable(b1_save)
x2 = tf.matmul(x1, w1)
y2 = tf.nn.relu(x2 + b1)

w2 = tf.Variable(w2_save)
b2 = tf.Variable(b2_save)
x3 = tf.matmul(x2, w2) + b2

global_step = tf.Variable(last_step, trainable=False)

# Compute the importance

In [10]:
# Recreate the graph:
tf.reset_default_graph()

x0 = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

w0 = tf.Variable(w0_save)
b0 = tf.Variable(b0_save)
x1 = tf.matmul(x0, w0)
y1 = tf.nn.relu(x1 + b0)

w1 = tf.Variable(w1_save)
b1 = tf.Variable(b1_save)
x2 = tf.matmul(x1, w1)
y2 = tf.nn.relu(x2 + b1)

w2 = tf.Variable(w2_save)
b2 = tf.Variable(b2_save)
x3 = tf.matmul(x2, w2) + b2

In [11]:
batch_size = 100
n_outputs = 10

# Define the required asolute values:
y0_abs = tf.abs(x0)
w0_abs = tf.abs(w0)

y1_abs = tf.abs(y1)
w1_abs = tf.abs(w1)

y2_abs = tf.abs(y2)
w2_abs = tf.abs(w2)

# Define i3
i3 = tf.constant(np.ones([batch_size, n_outputs], dtype=np.float32) / n_outputs)

In [12]:
# Define the backpropagation of importance
i2 = y2_abs * (tf.matmul((i3 / (tf.matmul(y2_abs, w2_abs))), tf.transpose(w2_abs)))

i1 = y1_abs * (tf.matmul((i2 / (tf.matmul(y1_abs, w1_abs))), tf.transpose(w1_abs)))

i0 = y0_abs * (tf.matmul((i1 / (tf.matmul(y0_abs, w0_abs))), tf.transpose(w0_abs)))

In [13]:
# Average the results for the batch
i2_batch_avg = tf.reduce_mean(i2, axis=0)
i1_batch_avg = tf.reduce_mean(i1, axis=0)
i0_batch_avg = tf.reduce_mean(i0, axis=0)

In [15]:
"""
Compute the importances over the chosen number of batches from the training set while maintaining
cumulative moving average.
"""
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# Create numpy arrays in which the moving average will be stored:
i0_avg, i1_avg, i2_avg = np.zeros([1, 784], dtype=np.float64), np.zeros([1, 800], dtype=np.float64), np.zeros([1, 800], dtype=np.float64)

for i in range(600): # 600 batches - iterate over the entire training set
    n = i + 1 # The number of elements of the average
    batch_xs, batch_ys = mnist.train.next_batch(batch_size)
    i0_curr, i1_curr, i2_curr = sess.run([i0_batch_avg, i1_batch_avg, i2_batch_avg],
                                         feed_dict={x0: batch_xs, y_:batch_ys})
    
    # Update the averages:
    i2_avg = (i2_avg * i + i2_curr) / (i + 1)
    i1_avg = (i1_avg * i + i1_curr) / (i + 1)
    i0_avg = (i0_avg * i + i0_curr) / (i + 1)
    
    if n % 10 == 0:
        print("Step: {}".format(n), end='\r')

Step: 600

In [21]:
print("The average importance of the second layer is: {:0.6f}".format(i1_avg.mean()))

The average importance of the second layer is: 0.001250


In [22]:
i0_avg = np.reshape(i0_avg, [784])
i1_avg = np.reshape(i1_avg, [800])
i2_avg = np.reshape(i2_avg, [800])

In [25]:
removal_thr = 0.5 # Threshold for how small the importance needs to be for the node to be removed.

print("Number of nodes that would be removed from the second layer (out of 800 total): {}".format(
    (i1_avg < i1_avg.mean()*removal_thr).sum()))

Number of nodes that would be removed from the second layer (out of 800 total): 229


In [26]:
l1_remove = (i1_avg < i1_avg.mean()*removal_thr).astype(int)
l2_remove = (i2_avg < i2_avg.mean()*removal_thr).astype(int)
l1_ix = np.argwhere(l1_remove)
l2_ix = np.argwhere(l2_remove)

In [27]:
# Remove the corresponding weights for:

# 1st hidden layer:
w0_save_reduced = np.delete(w0_save, l1_ix, axis=1)
w1_save_reduced = np.delete(w1_save, l1_ix, axis=0)
b0_save_reduced = np.delete(b0_save, l1_ix)

# 2nd hidden layer:
w1_save_reduced = np.delete(w1_save_reduced, l2_ix, axis=1)
w2_save_reduced = np.delete(w2_save, l2_ix, axis=0)
b1_save_reduced = np.delete(b1_save, l2_ix)

In [29]:
# Print the new size of each layer
print(w0_save_reduced.shape)
print(w1_save_reduced.shape)
print(w2_save_reduced.shape)

(784, 571)
(571, 499)
(499, 10)


## Test the thinned model

In [30]:
# Recreate the graph with pruned parameters:
tf.reset_default_graph()

x0 = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

w0 = tf.Variable(w0_save_reduced)
b0 = tf.Variable(b0_save_reduced)
x1 = tf.matmul(x0, w0)
y1 = tf.nn.relu(x1 + b0)

w1 = tf.Variable(w1_save_reduced)
b1 = tf.Variable(b1_save_reduced)
x2 = tf.matmul(x1, w1)
y2 = tf.nn.relu(x2 + b1)

w2 = tf.Variable(w2_save_reduced)
b2 = tf.Variable(b2_save)
x3 = tf.matmul(x2, w2) + b2

In [31]:
# Accuracy metrics
correct_prediction = tf.equal(tf.argmax(x3, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

In [32]:
# Evaluation:
pruned_accuracy = accuracy.eval(feed_dict={x0: mnist.test.images,
                                           y_: mnist.test.labels})

print('Baseline accuracy: {}'.format(baseline_accuracy))
print('Pruned accuracy:   {}'.format(pruned_accuracy))

Baseline accuracy: 0.8860999941825867
Pruned accuracy:   0.902400016784668
