# Pruning deep neural networks - MTCNN

Large models have millions of parameters and are memory-intensive. Moving all the data needed to compute an inference result consumes energy, and many layers are bandwidth-bound, meaning their execution latency is dominated by the available memory bandwidth. Storing and transferring large networks is a challenge as well. Network pruning can shrink a network's footprint, speed up inference, and reduce the energy, bandwidth and compute it requires. A related motivation is that models are typically over-parameterized and contain redundant features that contribute little to the output.

There are different types of sparsity patterns, ranging from irregular to regular, as shown below. The simplest case is element-wise sparsity, also called **fine-grained pruning**. Here individual weights are set to zero, so the tensors are sparse at element granularity. The weight tensors do not shrink, because the zero coefficients are still stored. Seeing a performance gain from fine-grained weight sparsity therefore requires specialized hardware: some NN accelerators (ASICs) exploit it by using a compressed representation of sparse tensors.
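As a minimal sketch of the point that fine-grained pruning leaves the tensor size unchanged, the snippet below zeroes the half of a weight tensor with the smallest magnitudes. The tensor shape and the 50% ratio are arbitrary choices for illustration, not values taken from MTCNN.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(4, 3, 3, 3)  # illustrative conv weight, shape (F, C, K, K)

# Zero out the 50% of elements with the smallest magnitude.
k = weight.numel() // 2
threshold = weight.abs().flatten().kthvalue(k).values
mask = weight.abs() > threshold
pruned = weight * mask

print(pruned.shape)                   # same shape as the dense tensor
print((pruned == 0).float().mean())   # fraction of exact zeros
```

The pruned tensor still occupies the same memory as the dense one, which is why a compressed storage format or dedicated hardware is needed to benefit from it.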

Coarse-grained pruning is also referred to as **structured pruning, group pruning or block pruning**. Structured pruning such as channel and filter pruning creates compressed models that do not require special hardware to execute, which makes it particularly interesting and popular. Convolution weights are 4D tensors of shape (F, C, K, K), where F is the number of filters, C is the number of input channels, and K is the kernel size. A kernel is a 2D matrix (K, K) that is one slice of a 3D filter (C, K, K).
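The (F, C, K, K) layout can be checked directly on a PyTorch layer. This is a small illustration using the same configuration as PNet's first convolution (3 input channels, 10 filters, 3x3 kernels); the variable names are only for this example.

```python
import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)

num_filters, num_channels, k_h, k_w = conv.weight.shape
print(num_filters, num_channels, k_h, k_w)

filter_0 = conv.weight[0]      # one 3D filter, shape (C, K, K)
kernel_00 = conv.weight[0, 0]  # one 2D kernel, shape (K, K)
print(filter_0.shape, kernel_00.shape)
```

Filter pruning removes whole slices along the first axis, which is why the result is still a dense, smaller tensor.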

<img src="imgs/1.png"  width="600" style="float: left;">

Recent works advocate "structured sparsity", where entire conv2d filters are pruned. With fewer filters in a layer, that layer produces fewer output feature maps. In the most common scenario, we reconfigure the convolution layer by reducing its "out_channels", together with the corresponding channels of the attached activation or batch normalization layer, and we also reduce the "in_channels" of the following convolution layer: its weights are shrunk by removing the input channels that correspond to the pruned filters. If the following layer is a fully connected layer, the corresponding input neurons are discarded. It has been observed that deeper layers tend to get pruned more.

This is a **data-dependency** type of pruning. Most state-of-the-art DNNs use more complicated structures such as ResNet, MobileNet and Inception layers, and the pruning strategy for them is quite different. I will explain how to prune MobileFaceNet in a separate tutorial. For now, understanding the pruning logic above is enough for MTCNN pruning.

<img src="imgs/2.png"  width="400" style="float: left;">

## Sparsity Definition 

Sparsity is a measure of how many elements in a tensor are exact zeros. The L0 **norm** function counts the non-zero elements in a tensor x: each element contributes either 1 or 0 to L0, and anything but an exact zero contributes 1. The elements that contribute 0 are the ones that count toward sparsity.
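A minimal sketch of this definition in PyTorch, assuming sparsity is reported as the fraction of exact zeros (one minus the density of non-zero elements). The helper name `sparsity` is made up for this example.

```python
import torch

def sparsity(t: torch.Tensor) -> float:
    """Fraction of elements in `t` that are exactly zero."""
    l0 = torch.count_nonzero(t).item()   # the L0 "norm": number of non-zeros
    return 1.0 - l0 / t.numel()

x = torch.tensor([[0.0, 1.5, 0.0], [2.0, 0.0, -3.0]])
print(sparsity(x))  # 3 zeros out of 6 elements -> 0.5
```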

<img src="imgs/3.png"  width="250" style="float: left;">

## Pruning Schedule

The most straightforward way to prune is to take a trained model and prune it once; this is called one-shot pruning. However, a prune-then-retrain regimen can achieve much better results (higher sparsity with no accuracy loss). This is called **iterative pruning**, and the retraining that follows each pruning step is often referred to as **fine-tuning**. Iterative pruning can be viewed as repeatedly learning which weights are important, removing the least important ones, and then retraining the model so it can "recover" from the pruning by adjusting the remaining weights. At each iteration, more weights are pruned.
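The prune / fine-tune loop can be sketched on a toy problem. The example below is only an illustration under stated assumptions: it uses element-wise magnitude pruning on a small linear model (not the filter pruning used for MTCNN later), and the data, sparsity targets, learning rate and the helper `fine_tune` are all invented for the sketch.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(20, 1)
x = torch.randn(256, 20)
y = x[:, :5].sum(dim=1, keepdim=True)   # toy target: only 5 inputs matter

def fine_tune(model, mask, steps=200):
    """Train with SGD while holding pruned weights at zero."""
    opt = torch.optim.SGD(model.parameters(), lr=0.05)
    for _ in range(steps):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        with torch.no_grad():
            model.weight.mul_(mask)     # re-apply the pruning mask
    return torch.nn.functional.mse_loss(model(x), y).item()

mask = torch.ones_like(model.weight)
fine_tune(model, mask)                  # initial training
for target in (0.25, 0.5, 0.75):        # prune a bit more each iteration
    with torch.no_grad():
        k = int(target * mask.numel())
        threshold = model.weight.abs().flatten().kthvalue(k).values
        mask = (model.weight.abs() > threshold).float()
        model.weight.mul_(mask)
    loss = fine_tune(model, mask)       # let the remaining weights recover
    print(f"target sparsity={target:.2f}  loss={loss:.5f}")
```

Already-pruned weights have magnitude zero, so they always fall below the next threshold and stay pruned from one iteration to the next.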

<img src="imgs/4.png"  width="200" style="float: left;">

## Pruning Criteria

Pruning requires a criterion for choosing which elements / kernels / filters to prune; this is called the pruning criterion. The most common criterion is the L1 norm of the weights of each filter. In each pruning iteration, all the filters are ranked, the m lowest-ranking filters are pruned, and the network is retrained before the process repeats. A more sophisticated approach ranks the filters by the effect each one has on the network cost, so that the filters whose removal changes the cost least are pruned first. This ranking is based on a first-order Taylor expansion of the network cost function:

<img src="imgs/5.png"  width="500" style="float: left;">

theta is the ranking score we care about, h_i refers to the filter being considered for pruning, and C is the total cost. The rank of filter **h** becomes the absolute value of the first-order Taylor term: the derivative of the cost with respect to the corresponding feature map, multiplied by that feature map. For example, if the feature map (activation) has shape (32 x 256 x 112 x 112), i.e. (batch_size x channels x height x width), the gradient of the cost with respect to it has the same shape. The point-wise product of each activation and its gradient is averaged over every dimension except the channel dimension, which yields a 256-element vector holding the ranks of the 256 filters in this layer. The ranks of each layer are then normalized by the L2 norm of the ranks in that layer, which is an empirical choice.

The whole idea comes from an [Nvidia](https://arxiv.org/abs/1611.06440) paper, in which this method outperformed the other criteria in accuracy.
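For comparison, the simpler L1-norm criterion mentioned earlier can be sketched in a few lines. The layer size and `m = 4` are arbitrary values chosen for illustration; the rest of this tutorial uses the Taylor criterion instead.

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv2d(16, 32, kernel_size=3)

# L1 norm of each filter: sum of |w| over the (C, K, K) axes.
l1_per_filter = conv.weight.detach().abs().sum(dim=(1, 2, 3))  # shape (32,)

m = 4
lowest = torch.argsort(l1_per_filter)[:m]  # indices of the m weakest filters
print(lowest.tolist())
```

Unlike the Taylor criterion, this ranking depends only on the weights and needs no forward or backward pass over data.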

The following sections explain step by step how to determine which filters to prune and how to prune the model. You may want to refer to the detailed code for a complete iterative pruning process.

## Determine the Filters to be Pruned

Let's get started with MTCNN pruning. Import the MTCNN networks and take a peek at the model structure.

In [1]:
import torch
import numpy as np
from MTCNN.Base_Model.MTCNN_nets import PNet, RNet, ONet

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
pnet = PNet(is_train=True).to(device)
pnet.load_state_dict(torch.load("MTCNN/Base_Model/pnet_Weights", map_location=lambda storage, loc: storage))
rnet = RNet(is_train=True).to(device)
rnet.load_state_dict(torch.load("MTCNN/Base_Model/rnet_Weights", map_location=lambda storage, loc: storage))
onet = ONet(is_train=True).to(device)
onet.load_state_dict(torch.load("MTCNN/Base_Model/onet_Weights", map_location=lambda storage, loc: storage))

pnet.train()
rnet.train()
onet.train()
print(pnet)
print(rnet)
print(onet)

PNet(
  (features): Sequential(
    (conv1): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
    (prelu1): PReLU(num_parameters=10)
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv2): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1))
    (prelu2): PReLU(num_parameters=16)
    (conv3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (prelu3): PReLU(num_parameters=32)
  )
  (conv4_1): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1))
  (conv4_2): Conv2d(32, 4, kernel_size=(1, 1), stride=(1, 1))
)
RNet(
  (features): Sequential(
    (conv1): Conv2d(3, 28, kernel_size=(3, 3), stride=(1, 1))
    (prelu1): PReLU(num_parameters=28)
    (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv2): Conv2d(28, 48, kernel_size=(3, 3), stride=(1, 1))
    (prelu2): PReLU(num_parameters=48)
    (pool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv3): Conv2d(48, 64, kernel_si

The procedure is to loop over all the feature maps produced by the conv2d layers, compute the first-order Taylor term, rank all the conv2d filters and prune the m lowest-ranking ones.

Ideally, the filter ranking should be computed using all the training data. For demonstration purposes, we use only a single batch of fake data here, consisting of input images, ground-truth labels and ground-truth bounding-box offsets. We will use PNet as an example.

In [3]:
# keep the fake batch on the same device as the model
input_images = torch.randn(3, 3, 12, 12).to(device)
gt_label = torch.tensor([1, 0, -1], dtype=torch.long).to(device)
gt_offset = torch.randn(3, 4).to(device)

To obtain the gradient of the total loss with respect to the intermediate feature maps, we use the **register_hook** function.
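As a small standalone illustration of `register_hook`, the snippet below captures the gradient of a loss with respect to an intermediate (non-leaf) tensor. The tensors and the `captured` dict are made up for the example.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
h = x * 2                      # intermediate tensor, like a feature map
captured = {}

def save_grad(grad):
    # Called during backward() with d(loss)/dh
    captured["h"] = grad

h.register_hook(save_grad)
loss = (h ** 2).sum()
loss.backward()

print(captured["h"])           # d(loss)/dh = 2 * h = tensor([ 4.,  8., 12.])
```

Intermediate tensors do not keep their `.grad` after the backward pass by default, which is why a hook is a convenient way to read it.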

A **FilterPrunner** class is built as below:

1. The **forward** function appends the intermediate feature maps produced by each conv2d layer to the activations list and registers a hook on each of them.

2. The **compute_rank** function multiplies the gradient point-wise with the activation, averages the Taylor value for each filter, and builds up the filter_ranks dict during the backward pass.

3. **normalize_ranks_per_layer** normalizes the Taylor values of the filters within each layer.

4. **lowest_ranking_filters** uses the "nsmallest" function to rank the filters.

5. **get_prunning_plan** returns the layer indices and filter indices that will be pruned.

In [4]:
import torch
import numpy as np 
from operator import itemgetter
from heapq import nsmallest

class FilterPrunner:
    def __init__(self, model, use_cuda = False):
        self.model = model
        self.reset()
        self.use_cuda = use_cuda

    def reset(self):
        self.filter_ranks = {}

    def forward(self, x):
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {}

        activation_index = 0
        for layer, (name, module) in enumerate(self.model.features._modules.items()):
            x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                x.register_hook(self.compute_rank)
                self.activations.append(x)
                self.activation_to_layer[activation_index] = layer # the ith conv2d layer
                activation_index += 1
                
        a = self.model.conv4_1(x)
        b = self.model.conv4_2(x)
        c = None 

        return c, b, a

    def compute_rank(self, grad):
        # Hooks fire in reverse order of the forward pass, so map the
        # running gradient counter back to the matching activation.
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]
        print("activation shape is: ", activation.shape)
        taylor = activation * grad

        # Get the average value for every filter,
        # across all the other dimensions
        taylor = taylor.mean(dim=(0, 2, 3)).detach()
        print("taylor shape is: ", taylor.shape)
        if activation_index not in self.filter_ranks:
            # Allocate the accumulator on the same device as the activation
            self.filter_ranks[activation_index] = \
                torch.zeros(activation.size(1), device=activation.device)

        self.filter_ranks[activation_index] += taylor
        self.grad_index += 1
        
    def lowest_ranking_filters(self, num):
        data = []
        for i in sorted(self.filter_ranks.keys()):
            for j in range(self.filter_ranks[i].size(0)):
                data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j]))

        return nsmallest(num, data, itemgetter(2))

    def normalize_ranks_per_layer(self):
        for i in self.filter_ranks:
            v = torch.abs(self.filter_ranks[i]).cpu()
            v = v / torch.sqrt(torch.sum(v * v))  # divide by the layer's L2 norm
            self.filter_ranks[i] = v

    def get_prunning_plan(self, num_filters_to_prune):
        filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune)
                
        filters_to_prune_per_layer = {}
        for (l, f, _) in filters_to_prune:
            if l not in filters_to_prune_per_layer:
                filters_to_prune_per_layer[l] = []
            filters_to_prune_per_layer[l].append(f)
    
    
        for l in filters_to_prune_per_layer:
            filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l])

        return filters_to_prune_per_layer

The next step is to calculate the loss and run backpropagation to obtain the gradients and build up filter_ranks. Note that if the whole training set is used for filter ranking, the Taylor value of each filter is accumulated across batches, which should make the ranking more accurate.

In [5]:
import torch.nn as nn
prunner = FilterPrunner(pnet)
prunner.reset()

loss_cls = nn.CrossEntropyLoss()
loss_offset = nn.MSELoss()

pnet.zero_grad()

with torch.set_grad_enabled(True):
        _, pred_offsets, pred_label = prunner.forward(input_images)
        pred_offsets = torch.squeeze(pred_offsets)
        pred_label = torch.squeeze(pred_label)
        # calculate the cls loss
        # keep samples with label >= 0; only labels 0 and 1 affect the classification loss
        mask_cls = torch.ge(gt_label, 0)
        valid_gt_label = gt_label[mask_cls]
        valid_pred_label = pred_label[mask_cls]

        # calculate the box loss
        # keep samples with label != 0; negatives carry no box offset
        mask_offset = torch.ne(gt_label, 0)
        valid_gt_offset = gt_offset[mask_offset]
        valid_pred_offset = pred_offsets[mask_offset]

        loss = torch.tensor(0.0).to(device)

        if len(valid_gt_label) != 0:
            loss += 0.02*loss_cls(valid_pred_label, valid_gt_label)

        if len(valid_gt_offset) != 0:
            loss += 0.6*loss_offset(valid_pred_offset, valid_gt_offset)

        loss.backward()

for i , taylor in prunner.filter_ranks.items():
    print(i, taylor.shape)

activation shape is:  torch.Size([3, 32, 1, 1])
taylor shape is:  torch.Size([32])
activation shape is:  torch.Size([3, 16, 3, 3])
taylor shape is:  torch.Size([16])
activation shape is:  torch.Size([3, 10, 10, 10])
taylor shape is:  torch.Size([10])
2 torch.Size([32])
1 torch.Size([16])
0 torch.Size([10])


Normalize the Taylor values by the L2 norm of the ranks in each layer.

In [6]:
prunner.normalize_ranks_per_layer()

Assume we would like to prune 10 filters. **lowest_ranking_filters** ranks the filters by their Taylor values and returns the 10 lowest-ranking ones in the format (layer_index, filter_index, taylor_value).

In [7]:
num_filters_to_prune = 10
filters_to_prune = prunner.lowest_ranking_filters(num_filters_to_prune) 
filters_to_prune

[(5, 21, tensor(0.0002)),
 (5, 20, tensor(0.0003)),
 (5, 1, tensor(0.0003)),
 (5, 16, tensor(0.0003)),
 (5, 0, tensor(0.0009)),
 (3, 8, tensor(0.0009)),
 (5, 8, tensor(0.0010)),
 (5, 18, tensor(0.0012)),
 (5, 29, tensor(0.0025)),
 (5, 11, tensor(0.0028))]

**get_prunning_plan** returns a dict whose keys are layer indices and whose values are the sorted lists of filter indices to prune in that layer.

In [8]:
filters_to_prune_per_layer = prunner.get_prunning_plan(num_filters_to_prune)
filters_to_prune_per_layer

{5: [0, 1, 8, 11, 16, 18, 20, 21, 29], 3: [8]}
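The selection and grouping steps can also be tried on plain tuples, without a model. The scores below are invented for illustration; only `nsmallest` and `itemgetter` come from the code above.

```python
from collections import defaultdict
from heapq import nsmallest
from operator import itemgetter

# (layer_index, filter_index, score) triples with made-up scores
data = [(0, 0, 0.30), (0, 1, 0.05), (3, 0, 0.20),
        (3, 1, 0.01), (3, 2, 0.02), (5, 0, 0.10)]

# Pick the 3 entries with the smallest score (third tuple field).
lowest = nsmallest(3, data, key=itemgetter(2))
print(lowest)  # [(3, 1, 0.01), (3, 2, 0.02), (0, 1, 0.05)]

# Group the selected filter indices by layer, sorted within each layer.
plan = defaultdict(list)
for layer, filt, _ in lowest:
    plan[layer].append(filt)
plan = {layer: sorted(filters) for layer, filters in plan.items()}
print(plan)    # {3: [1, 2], 0: [1]}
```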

## How to Prune the MTCNN Network

We have now found the layer indices and filter indices with the smallest Taylor values. These are the filters we would like to remove from the network. But how do we actually prune the network? Let's move on.

There are two scenarios when pruning the MTCNN nets. Once certain filters of a conv2d layer have been chosen for pruning, we first prune the corresponding channels of that conv2d layer and of the attached activation layer (i.e. PReLU) and batch normalization layer, if any. Pooling layers need no change. Then, (A) if the next linked layer is a conv2d layer, we need to reconfigure its in_channels; (B) if the next linked layer is a fully connected layer, the corresponding input neurons need to be discarded.

Let's use ONet as an example, since it has both conv layers and linear layers. Assume we want to prune filters 3 and 6 of layer 0.

In [9]:
layer_index = 0
filter_index = (3,6)

_, conv = list(onet.features._modules.items())[layer_index]
_, PReLU = list(onet.features._modules.items())[layer_index+1]

The new conv and PReLU layers are constructed by removing the corresponding filter weights.

In [10]:
new_conv = \
        torch.nn.Conv2d(in_channels=conv.in_channels, \
                        out_channels=conv.out_channels - len(filter_index),
                        kernel_size=conv.kernel_size, \
                        stride=conv.stride,
                        padding=conv.padding,
                        dilation=conv.dilation,
                        groups=conv.groups,
                        bias=(conv.bias is not None))

old_weights = conv.weight.data.cpu().numpy()  
print("conv old_weight shape is: ", old_weights.shape)
new_weights = np.delete(old_weights, filter_index, axis=0)  
new_conv.weight.data = torch.from_numpy(new_weights)
print("conv new_weight shape is: ", new_weights.shape)

bias_numpy = conv.bias.data.cpu().numpy()
print("conv old_bias shape is: ", bias_numpy.shape)
bias = np.delete(bias_numpy, filter_index)
new_conv.bias.data = torch.from_numpy(bias)
print("conv new_bias shape is: ", bias.shape)

# The new PReLU layer constructed as follow:    
new_PReLU = torch.nn.PReLU(num_parameters=PReLU.num_parameters-len(filter_index))
old_weights = PReLU.weight.data.cpu().numpy()
print("PReLU old_weight's shape is: ", old_weights.shape)
new_weights = np.delete(old_weights, filter_index)
new_PReLU.weight.data = torch.from_numpy(new_weights)
print("PReLU new_weight's shape is: ", new_weights.shape)

conv old_weight shape is:  (32, 3, 3, 3)
conv new_weight shape is:  (30, 3, 3, 3)
conv old_bias shape is:  (32,)
conv new_bias shape is:  (30,)
PReLU old_weight's shape is:  (32,)
PReLU new_weight's shape is:  (30,)
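One way to sanity-check this kind of surgery is to compare the pruned layer's output with the original output after dropping the same channels. The sketch below does this on a freshly initialized stand-in layer with the same configuration as ONet's first conv, and it uses tensor indexing with a `keep` list instead of `np.delete`; both are choices made for the example, not part of the original code.

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 32, kernel_size=3)   # stand-in for ONet conv1
filter_index = (3, 6)

# Indices of the filters that survive pruning.
keep = [i for i in range(conv.out_channels) if i not in filter_index]

new_conv = torch.nn.Conv2d(3, len(keep), kernel_size=3)
with torch.no_grad():
    new_conv.weight.copy_(conv.weight[keep])
    new_conv.bias.copy_(conv.bias[keep])

x = torch.randn(2, 3, 48, 48)
with torch.no_grad():
    expected = conv(x)[:, keep]   # original output minus pruned channels
    actual = new_conv(x)

print(actual.shape)
print(torch.allclose(expected, actual, atol=1e-5))
```

The remaining channels are untouched by the surgery, so any drop in accuracy comes only from the missing channels and is what fine-tuning is meant to recover.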


Next, find the following conv layer. Its in_channels must be reduced and the corresponding weights removed. Note that the activation or batch normalization layer attached to this next conv does not need to be edited, since its output channels are unchanged.

In [11]:
next_conv = None
offset = 1
while layer_index + offset < len(onet.features._modules.items()):
    res = list(onet.features._modules.items())[layer_index + offset]
    if isinstance(res[1], torch.nn.modules.conv.Conv2d):
        next_name, next_conv = res
        break
    offset = offset + 1

print(next_conv)

Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))


In [12]:
if next_conv is not None:
    next_new_conv = \
        torch.nn.Conv2d(in_channels=next_conv.in_channels - len(filter_index), \
                        out_channels=next_conv.out_channels, \
                        kernel_size=next_conv.kernel_size, \
                        stride=next_conv.stride,
                        padding=next_conv.padding,
                        dilation=next_conv.dilation,
                        groups=next_conv.groups,
                        bias=(next_conv.bias is not None))

    old_weights = next_conv.weight.data.cpu().numpy() 
    print("conv old_weight shape is: ", old_weights.shape)
    new_weights = np.delete(old_weights, filter_index, axis=1)  
    next_new_conv.weight.data = torch.from_numpy(new_weights)
    print("conv new_weight shape is: ", new_weights.shape)

    next_new_conv.bias.data = next_conv.bias.data  # bias is not changed

conv old_weight shape is:  (64, 32, 3, 3)
conv new_weight shape is:  (64, 30, 3, 3)


Replace the old layers with the newly constructed ones.

In [13]:
def replace_layers(model, i, indexes, layers):

    """
    replace layers of model.features

    :param model: the Sequential container (model.features)
    :param i: index of the current layer in model.features
    :param indexes: array of indexes of layers to be replaced
    :param layers: array of new layers to replace
    :return: model with replaced layers
    """
    if i in indexes:
        return layers[indexes.index(i)]
    return model[i]

In [14]:
features = torch.nn.Sequential(
    *(replace_layers(onet.features, i, [layer_index, layer_index+1, layer_index + offset], \
                     [new_conv, new_PReLU, next_new_conv]) for i, _ in enumerate(onet.features)))
del onet.features  # reset
del conv # reset

onet.features = features
onet

ONet(
  (conv6_1): Linear(in_features=256, out_features=2, bias=True)
  (conv6_2): Linear(in_features=256, out_features=4, bias=True)
  (conv6_3): Linear(in_features=256, out_features=10, bias=True)
  (features): Sequential(
    (0): Conv2d(3, 30, kernel_size=(3, 3), stride=(1, 1))
    (1): PReLU(num_parameters=30)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Conv2d(30, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): PReLU(num_parameters=64)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (7): PReLU(num_parameters=64)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (9): Conv2d(64, 128, kernel_size=(2, 2), stride=(1, 1))
    (10): PReLU(num_parameters=128)
    (11): Flatten()
    (12): Linear(in_features=1152, out_features=256, bias=True)
    (13): Dropout(p=0.25)
    (14): PReLU(num_parameters=256)
  )
)

If the next linked layer is a fully connected layer, the corresponding input neurons need to be discarded. Assume we want to prune filters 4 and 30 of the conv2d layer at index 9.

We omit the conv2d reconstruction here and show only the fully connected layer.

In [15]:
layer_index = 9
filter_index = (4,30)

_, conv = list(onet.features._modules.items())[layer_index]

linear_layer = None
offset = 1
while layer_index + offset < len(onet.features._modules.items()):
    res = list(onet.features._modules.items())[layer_index + offset]
    if isinstance(res[1], torch.nn.Linear):
        layer_name, linear_layer = res
        break
    offset = offset + 1

print(conv)
print(linear_layer)

Conv2d(64, 128, kernel_size=(2, 2), stride=(1, 1))
Linear(in_features=1152, out_features=256, bias=True)


The number of parameters per input channel is the linear layer's in_features divided by the out_channels of the preceding conv layer.
The input neurons to delete are then derived from the filter indices and this per-channel count.

In [16]:
params_per_input_channel = linear_layer.in_features // conv.out_channels

new_linear_layer = torch.nn.Linear(linear_layer.in_features - len(filter_index)*params_per_input_channel,linear_layer.out_features)

old_weights = linear_layer.weight.data.cpu().numpy()  #i.e. (out_feature x in_feature)
print('linear old weights shape is: ', old_weights.shape)

delete_array = []
for f in filter_index:
    # each pruned filter owns a contiguous block of input neurons
    delete_array += [f * params_per_input_channel + x for x in range(params_per_input_channel)]
# remove all selected input columns in one pass
new_weights = np.delete(old_weights, delete_array, axis=1)
print('linear new weights shape is: ', new_weights.shape)
new_linear_layer.bias.data = linear_layer.bias.data
new_linear_layer.weight.data = torch.from_numpy(new_weights)

linear old weights shape is:  (256, 1152)
linear new weights shape is:  (256, 1134)
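To see why these column indices are the right ones, the following sketch fills a fake feature map with its own channel index and flattens it. It assumes the Flatten layer is a plain row-major reshape of an (N, C, H, W) tensor and that the last conv output is 128 x 3 x 3 (which matches 1152 in_features); if the network permutes dimensions before flattening, the mapping would differ.

```python
import torch

channels, h, w = 128, 3, 3
# Fake feature map where every element stores its own channel index.
fmap = torch.arange(channels).view(1, channels, 1, 1).expand(1, channels, h, w)
flat = fmap.reshape(1, -1)     # row-major flatten, shape (1, 1152)

per_channel = h * w            # 9 input neurons per conv channel
filter_index = (4, 30)
delete = [f * per_channel + k for f in filter_index for k in range(per_channel)]

# The selected columns should belong to channels 4 and 30 only.
print(sorted(set(flat[0, delete].tolist())))  # [4, 30]
```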
