<a href="https://colab.research.google.com/github/shreyashpatodia/for-ai-challenge/blob/master/for_ai_challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gotta prune 'em all

## Introduction

This is Shreyash Patodia's submission to for.ai's pruning challenge.

Reproducing the results from this colab should be as simple as running all the cells in order!

I've tried to add some context using text and comments to make things easier to understand.

## Understanding the challenge

* Write a ReLU activated neural network with hidden layer sizes [1000, 1000, 500, 200].

* Train network on Fashion-MNIST or MNIST. Choice: Fashion-MNIST.

* Prune network using weight pruning and unit pruning. __This is post-hoc pruning and no pruning needs to happen during the training of the network__. Pruning percentages: [0, 25, 50, 60, 70, 80, 90, 95, 97, 99].

* Visualize results

* Analyze Results

* _Bonus_: Speed up neural network execution using new found sparsity.

In [0]:
#@title Importing Libraries { form-width: "200px", display-mode: "form" }
from __future__ import print_function
from __future__ import division 
from __future__ import absolute_import 

import tensorflow as tf
from tensorflow import keras

import numpy as np
import pandas as pd
import altair as alt
import copy

tf.enable_eager_execution()

print('Tensorflow version: ', tf.__version__)
print("Executing eagerly")

Tensorflow version:  1.14.0
Executing eagerly


## The neural network

Defining a Keras neural network that takes input_shape, hidden_sizes and output_size as arguments. I chose keras.Sequential over keras.Model because the model is simple enough not to need any bespoke functionality.

In [0]:
def create_network(input_shape, hidden_sizes, output_size):
  """
  Creates a dense network based on the parameters provided.
  
  Create a network which takes input of shape, input_shape, flattens
  the input and then passes it through hidden layers whose sizes are
  parameterized by values in hidden_sizes and has an output of size,
  output_size.
  
  Args:
    input_shape: Shape of the input. For example, (28, 28).
    hidden_sizes: List with the number of units in each hidden layer. For
      example, [1000, 1000, 500, 200].
    output_size: Number of classes in the output.
  
  Returns:
    model: The model of type keras.Sequential with the given specs.
  """
  
  layers = [keras.layers.Flatten(input_shape=input_shape)]
  for i in range(len(hidden_sizes)):
    layers.append(keras.layers.Dense(hidden_sizes[i],
                                     activation=tf.nn.relu,
                                     use_bias=False))
  layers.append(keras.layers.Dense(output_size, use_bias=False))
  model = keras.Sequential(layers)
  return model 
  

## The Data

Choice of dataset: __Fashion MNIST__.

We start off by defining constants based on the data (we also define the constants needed for the neural network here) and then use the function fashion_mnist to get the data.

In [0]:
#@title Constants { run: "auto", vertical-output: true, display-mode: "form" }
input_shape = (28, 28) #@param {type:"raw"}
hidden_sizes = [1000, 1000, 500, 200] #@param {type:"raw"}
output_size = 10 #@param {type:"integer"}
batch_size = 128 #@param {type:"integer"}


In [0]:
def rescale_data(images):
  """
  Normalizes the images provided to a 0-1 scale.
  
  Args:
    images: inputs to be rescaled.
  
  """
  return tf.cast(images / 255.0, tf.float32)

def fashion_mnist(batch_size):
  """
  Loads fashion_mnist as tf.data.Dataset.
  
  Args:
    batch_size: Size the dataset should be batched into.
    
  Returns:
    training_data: Batched training_data.
    test_data: Batched test_data.
  """
  num_classes = 10
  prefetch_size = 10
  fashion_mnist = keras.datasets.fashion_mnist
  training_data, test_data = fashion_mnist.load_data()
  
  training_images, training_labels = training_data
  test_images, test_labels = test_data
  
  training_images = rescale_data(training_images)
  test_images = rescale_data(test_images)
  
  training_data = training_images, tf.one_hot(training_labels, num_classes)
  test_data = test_images, tf.one_hot(test_labels, num_classes)
  
  training_data = tf.data.Dataset.from_tensor_slices(training_data)
  training_data = training_data.batch(batch_size)
  training_data = training_data.shuffle(prefetch_size**2)
  training_data = training_data.prefetch(prefetch_size)
  
  test_data = tf.data.Dataset.from_tensor_slices(test_data)
  test_data = test_data.batch(batch_size)
  test_data = test_data.shuffle(prefetch_size**2)
  test_data = test_data.prefetch(prefetch_size)
  
  return training_data, test_data

In [0]:
training_data, test_data = fashion_mnist(batch_size)

## Training the network

We train the network for 30 epochs using the training data.

Optimizer used: Adam.


In [0]:
model = create_network(input_shape, hidden_sizes, output_size)

In [0]:
model.compile(optimizer='adam',
              loss=tf.losses.softmax_cross_entropy,
              metrics=['accuracy'])

In [0]:
history = model.fit(training_data, epochs=30, verbose=1)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [0]:
# sanity check to make sure the model learnt something.
test_loss, test_acc = model.evaluate(test_data)



## Pruning Networks

We define a general prune_model function that takes a parameter called pruning_method, which is used to prune each hidden layer of the network. The output layer is copied over unpruned.

This allows us to implement layerwise weight and unit pruning and use them interchangeably!

In [0]:
def prune_model(model_weights, pruning_percent, pruning_method):
  """
  Prune model goes through all the hidden layers of a network and prunes them.
  
  Args:
    model_weights: list of model_weights.
    pruning_percent: the percentile of weights/units to be pruned.
    pruning_method: can be one of `prune_layer_weights` and `prune_layer_units`.
  
  Returns:
    pruned_layerwise_weights: pruned weights of the network.
  """
  layerwise_hidden_weights = model_weights[:-1]
  output_weights = model_weights[-1]
  pruned_layerwise_weights = [] 
  for hidden_layer_weights in layerwise_hidden_weights:
    pruned_layerwise_weights.append(pruning_method(
        hidden_layer_weights, pruning_percent))
  pruned_layerwise_weights.append(np.copy(output_weights))
  return pruned_layerwise_weights
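
To make the role of the pruning_method parameter concrete, here is a small sketch on toy data. It uses a compact copy of prune_model so that it runs on its own, and `zero_everything` is a hypothetical stand-in strategy invented for this illustration, not one of the two methods used in this notebook.

```python
import numpy as np

def prune_model(model_weights, pruning_percent, pruning_method):
  # Compact copy of the notebook's prune_model so this cell runs on its own.
  pruned = [pruning_method(w, pruning_percent) for w in model_weights[:-1]]
  pruned.append(np.copy(model_weights[-1]))
  return pruned

def zero_everything(layer_weights, pruning_percent):
  # Hypothetical stand-in strategy: ignores the percentage and zeroes the layer.
  return np.zeros_like(layer_weights)

# Three toy weight matrices; the last one plays the role of the output layer.
toy_weights = [np.ones((4, 3)), np.ones((3, 2)), np.ones((2, 2))]
pruned = prune_model(toy_weights, 50, zero_everything)

# The hidden layers are zeroed, the output layer is left untouched.
print([float(w.sum()) for w in pruned])
```

Any function with the signature `(layer_weights, pruning_percent)` can be plugged in the same way.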


### Weight Pruning

We perform layerwise weight pruning: in each layer, we zero out the bottom k% of the weights, ranked by absolute value.

In [0]:
def prune_layer_weights(layer_weights, pruning_percent):
  """
  Zeroes out the `pruning_percent` percent of weights with the smallest magnitude in a layer.
  
  Args:
    layer_weights: weights of the layer.
    pruning_percent: percent of weights to prune.
  """
  # find abs value 
  abs_layer_weights = np.absolute(layer_weights)
  # find threshold for pruning
  threshold = np.percentile(abs_layer_weights, pruning_percent)
  
  pruned_layer_weights = np.copy(layer_weights)
  
  # prune (make = 0 ) weights below the threshold
  pruned_layer_weights[abs_layer_weights < threshold] = 0
  
  return pruned_layer_weights
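
As a quick worked example of the thresholding step, the sketch below applies the same logic to a hand-picked 2x4 matrix. The values are made up for illustration and are not taken from the trained network.

```python
import numpy as np

# Toy 2x4 "layer" with hand-picked weights (illustrative values only).
toy_weights = np.array([[0.1, -0.8, 0.3, -0.05],
                        [-0.6, 0.2, -0.9, 0.4]])

pruning_percent = 50
abs_weights = np.absolute(toy_weights)
# With 8 weights, the 50th percentile falls between the 4th and 5th smallest
# magnitudes (0.3 and 0.4), so np.percentile interpolates between them.
threshold = np.percentile(abs_weights, pruning_percent)

pruned = np.copy(toy_weights)
pruned[abs_weights < threshold] = 0

# Fraction of weights that are now exactly zero.
sparsity = np.mean(pruned == 0)
print(threshold, sparsity)
print(pruned)
```

Because the comparison is a strict `<`, weights whose magnitude equals the threshold are kept.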

### Unit Pruning

In [0]:
def prune_layer_units(layer_weights, pruning_percent):
  """
  Prunes away `pruning_percent` percent of the units in a layer.
  
  Args:
    layer_weights: weights of the layer.
    pruning_percent: percent of units to prune.
  """
  # find columnwise l2 norm
  layerwise_l2_norm = np.linalg.norm(layer_weights, axis=0)
  # find threshold based on l2 norm
  threshold = np.percentile(layerwise_l2_norm, pruning_percent)
   
  pruned_layer_weights = np.copy(layer_weights)
  
  # prune units based on l2 norm (transpose helps use broadcasting of l2_norm
  # for pruning).
  pruned_layer_weights = np.transpose(pruned_layer_weights) 
  pruned_layer_weights[layerwise_l2_norm < threshold] = 0
  pruned_layer_weights = np.transpose(pruned_layer_weights)
  
  return pruned_layer_weights
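
The same idea on a toy matrix, where each column holds the incoming weights of one unit. The values are invented for illustration. This sketch masks columns directly with `pruned[:, mask] = 0`, which should be equivalent to the transpose, mask, transpose sequence above.

```python
import numpy as np

# Toy layer: 3 inputs (rows) feeding 4 units (columns); values are illustrative.
toy_weights = np.array([[0.1, 2.0, 0.0, -1.0],
                        [0.2, -1.0, 0.3, 1.0],
                        [-0.2, 2.0, 0.4, 1.0]])

# One L2 norm per unit (column): roughly 0.3, 3.0, 0.5 and 1.73.
unit_norms = np.linalg.norm(toy_weights, axis=0)
threshold = np.percentile(unit_norms, 50)

pruned = np.copy(toy_weights)
# Boolean mask over columns: zero every weight of a unit below the threshold.
pruned[:, unit_norms < threshold] = 0

print(unit_norms < threshold)
print(pruned)
```

Units 0 and 2 have the smallest norms, so their columns are zeroed in full while units 1 and 3 are left unchanged.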



## Evaluating pruning methods


Defining a list of all the pruning percentages:

In [0]:
#@title
ks = [0, 25, 50, 60, 70, 80, 90, 95, 97, 99] #@param {type:"raw"}

In [0]:
def eval_model_pruning(model, pruning_method, test_data, ks):
  """
  Evaluates a model pruned using `pruning_method` across different percentiles.
  
  Args:
    model: model whose weights are to be pruned.
    pruning_method: pruning strategy. either of `prune_layer_weights` or
      `prune_layer_units`.
    test_data: eval data.
    ks: list of the percentiles to be pruned away.
  """
  
  # model whose weights we'll set to pruned weights
  eval_model = create_network(input_shape, hidden_sizes, output_size)
  eval_model.compile(
      optimizer='adam',
      loss=tf.losses.softmax_cross_entropy,
      metrics=['accuracy']
  )
  model_weights = model.get_weights()
  pruning_losses = {}
  pruning_accuracies = {}
  for k in ks:
    pruned_layerwise_weights = prune_model(model_weights, k, pruning_method) 
    eval_model.set_weights(pruned_layerwise_weights)
    loss, acc = eval_model.evaluate(test_data)
    pruning_losses[k] = loss
    pruning_accuracies[k] = acc
  return pruning_losses, pruning_accuracies

### Weight Pruning

In [0]:
weight_pruning_losses, weight_pruning_accuracies = eval_model_pruning(
  model, prune_layer_weights, test_data, ks=ks)
  



### Unit Pruning

In [0]:
unit_pruning_losses, unit_pruning_accuracies = eval_model_pruning(
  model, prune_layer_units, test_data, ks=ks)



## Visualize Results

In [0]:
def visualize_results(accuracies):
  """
  Helps create a table and plot charts for accuracy vs pruning.
  
  Args:
    accuracies: a dict mapping each pruning percent to the accuracy at that percent.
  """
  
  accuracies = pd.DataFrame({
      'pruning_percent': list(accuracies.keys()),
      'accuracies': list(accuracies.values())
  })
  
  print(accuracies)
  
  chart = alt.Chart(accuracies, height=300, width=300).mark_line().encode(
      x='pruning_percent',
      y='accuracies'
  ).properties(background='white').interactive()
  
  chart.display()
  

### Visualize weight pruning

In [0]:
visualize_results(weight_pruning_accuracies)

   pruning_percent  accuracies
0                0      0.8880
1               25      0.8881
2               50      0.8880
3               60      0.8879
4               70      0.8840
5               80      0.8656
6               90      0.7716
7               95      0.3710
8               97      0.1938
9               99      0.0828


### Visualize Unit Pruning

In [0]:
visualize_results(unit_pruning_accuracies)

   pruning_percent  accuracies
0                0      0.8880
1               25      0.8881
2               50      0.8796
3               60      0.8249
4               70      0.7930
5               80      0.3373
6               90      0.2175
7               95      0.1453
8               97      0.0954
9               99      0.1266


## Analyzing Results

### General Takeaway

The network trained on Fashion-MNIST seems to be fairly robust to post-hoc pruning irrespective of the pruning strategy. We can prune off large fractions of the neural network without seeing a considerable (or proportionate) drop in performance: a 25% reduction in network size does not decrease accuracy by 25%. In the tables above, accuracy at 25% pruning is 0.8881 for both methods, against 0.8880 for the unpruned network.

The fact that we can prune off such large parts of the network lends support to the lottery ticket hypothesis and to the idea that some parts of the network are more important than others. It also suggests that the magnitude of the weights/units is a good indicator of their importance to the network, which is in line with the intuition that larger weights contribute more to the output and are thus more important.

I also suspect there are redundancies in what different parts of the network learn, which would explain why it tolerates weight pruning of up to 70% with barely any loss in accuracy. Performance only truly degrades once all the parts that learnt a specific "feature" are lost. This matters most once we start pruning weights that are not small and sit in the top 50% by magnitude. Those weights may well have learnt something useful, but another part of the network may have learnt a similar feature and learnt it better, making the pruned portion less useful and thus expendable.

### Comparing Weight Pruning and Unit Pruning

I think of unit pruning as a special case of weight pruning: if all the weights in a unit were "not important" to the network based on their magnitude, weight pruning would prune them away too. This led me to the intuition that weight pruning should perform at least as well as unit pruning, if not better.

My intuition was reinforced by the results: at pruning_percent = 70, weight pruning had barely suffered any loss in performance (< 1%), but unit pruning had gone down by almost 10 percentage points (0.8880 to 0.7930). This might be because a unit with a small overall norm can still contain a few influential weights, and pruning the whole unit removes those weights too.
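
One way to check this intuition numerically is to compare how much total weight magnitude each method removes at the same pruning percentage. The sketch below does this on a random 8x8 matrix, which is only a stand-in for a trained layer, so it illustrates the argument rather than reproducing the results above.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed; random weights stand in for a trained layer
weights = rng.normal(size=(8, 8))
percent = 50

# Weight pruning: zero individual weights below the magnitude percentile.
abs_weights = np.absolute(weights)
weight_mask = abs_weights < np.percentile(abs_weights, percent)

# Unit pruning: zero whole columns whose L2 norm is below the norm percentile.
norms = np.linalg.norm(weights, axis=0)
unit_mask = np.broadcast_to(norms < np.percentile(norms, percent), weights.shape)

# Total magnitude removed by each method.
removed_by_weight = abs_weights[weight_mask].sum()
removed_by_unit = abs_weights[unit_mask].sum()
print(weight_mask.sum(), unit_mask.sum())
print(removed_by_weight, removed_by_unit)
```

Both methods zero the same number of weights here, but weight pruning picks the smallest ones individually, so it can never remove more total magnitude than unit pruning does for the same count.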

The downside of weight pruning is that it produces sparse tensors, and optimizing inference time over sparse tensors isn't the easiest. With unit pruning, it is straightforward to remove neurons outright, and the performance gains from removing entire columns from our matrices might be significantly higher than those from making our tensors sparser.

## Bonus

### Unit Level Sparsity

It is fairly straightforward to remove all the columns of a matrix that have been zeroed out. But we also want to remove the rows of the next matrix that correspond to the zeroed-out columns of the current one.

Here is an implementation which uses masks to achieve this:

In [0]:
k = 50

model_weights = model.get_weights()
pruned_layerwise_weights = prune_model(model_weights, k, prune_layer_units) 
flattened_input_shape = pruned_layerwise_weights[0].shape[0]

# row_bool_mask removes rows (corresponding to connection with a specific
# neuron from the previous layer for all the neurons in the current layer)
# initialize to True because we don't prune inputs
row_bool_mask = tf.fill([flattened_input_shape], True)

minimized_hidden_sizes = []
minimized_layerwise_weights = []
for pruned_layer_weights in pruned_layerwise_weights:
  
  layer_weights = tf.identity(pruned_layer_weights)
  # remove rows based on the neurons from the previous layers that have been
  # pruned
  layer_weights = tf.boolean_mask(layer_weights, row_bool_mask)
  
  # find which columns are all zeros and create a mask appropriately
  intermediate_tensor = tf.reduce_sum(
      tf.abs(layer_weights), axis=0)
  zero_vector = tf.zeros(shape=(1,1), dtype=tf.float32)
  col_bool_mask = tf.squeeze(tf.not_equal(intermediate_tensor, zero_vector))
  
  # transpose -> prune -> transpose of transpose to get back needed shape
  layer_weights = tf.transpose(layer_weights)
  layer_weights = tf.boolean_mask(layer_weights, col_bool_mask)
  layer_weights = tf.transpose(layer_weights)
  
  minimized_hidden_sizes.append(layer_weights.shape[1].value)
  minimized_layerwise_weights.append(layer_weights.numpy())
  row_bool_mask = col_bool_mask
  
# remove output size
minimized_hidden_sizes = minimized_hidden_sizes[:-1]

mini_network = create_network(input_shape, minimized_hidden_sizes, output_size)
mini_network.compile(
    optimizer='adam',
    loss=tf.losses.softmax_cross_entropy,
    metrics=['accuracy']
    
)
mini_network.set_weights(minimized_layerwise_weights)
mini_network.evaluate(test_data)
  
  





[0.40370800431016124, 0.8796]

Same accuracy for 50% sparsity of units as the larger network containing the zero vectors!

This should speed up execution, since every matrix multiplication now operates on smaller matrices.
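
The reason the accuracy is unchanged can be seen on a tiny example. The sketch below uses plain NumPy and a hypothetical two-layer slice of a bias-free ReLU network (the sizes and random weights are assumptions made for illustration). It relies on the layers having no bias, as in create_network: a zeroed column always produces a zero activation, so its row in the next matrix never contributes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical slice of a bias-free ReLU network: 5 inputs -> 4 hidden -> 3 outputs.
w1 = rng.normal(size=(5, 4))
w2 = rng.normal(size=(4, 3))
w1[:, [1, 3]] = 0  # pretend unit pruning zeroed hidden units 1 and 3

keep = np.any(w1 != 0, axis=0)  # hidden units that survived pruning
w1_small = w1[:, keep]          # drop the zeroed columns
w2_small = w2[keep, :]          # drop the matching rows of the next layer

def relu(z):
  return np.maximum(z, 0)

x = rng.normal(size=(2, 5))
full = relu(x @ w1) @ w2
small = relu(x @ w1_small) @ w2_small
print(w1_small.shape, w2_small.shape, np.allclose(full, small))
```

The smaller matrices give the same outputs as the zero-padded ones, which matches the identical accuracy observed above.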

### Weight Level Sparsity

We could use tf.sparse to convert our dense tensors with many zeros into sparse tensors, making them more memory-efficient.

I am not sure sparse tensors would give an obvious speedup for this small network. For very large, very sparse tensors, though, I can see sparse versions of the operations needed at inference time, such as matrix multiplication and maximum (for ReLU), leading to quicker execution.
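
To illustrate the memory side of this, here is a rough sketch in plain NumPy rather than tf.sparse. It stores only the coordinates and values of the nonzero weights (a COO-style layout) and multiplies using just those entries. The matrix size, the random weights and the helper `coo_matmul` are all assumptions made for this example, and it makes no claim about speed.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed; random values stand in for trained weights
dense = rng.normal(size=(200, 100)).astype(np.float32)
abs_dense = np.absolute(dense)
dense[abs_dense < np.percentile(abs_dense, 95)] = 0  # roughly 95% weight sparsity

# COO-style storage: keep only the coordinates and values of nonzero weights.
rows, cols = np.nonzero(dense)
values = dense[rows, cols]

def coo_matmul(x, rows, cols, values, num_cols):
  """Computes x @ W from the nonzero entries of W only (hypothetical helper)."""
  out_t = np.zeros((num_cols, x.shape[0]), dtype=x.dtype)
  # Each nonzero W[r, c] adds x[:, r] * W[r, c] to output column c.
  np.add.at(out_t, cols, (x[:, rows] * values).T)
  return out_t.T

x = rng.normal(size=(4, 200)).astype(np.float32)
sparse_out = coo_matmul(x, rows, cols, values, dense.shape[1])
sparse_bytes = rows.nbytes + cols.nbytes + values.nbytes
print(dense.nbytes, sparse_bytes, np.allclose(sparse_out, x @ dense, atol=1e-4))
```

The index arrays cost memory too, so this layout only pays off once the matrix is sparse enough.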

In [0]:
%%html
<marquee style='height=100px; width: 100%; color: red;'><b>That's all folks!</b></marquee>