# Pruning deep neural networks: MobileFaceNet

Pruning MobileFaceNet can be a little tricky. The detailed model architecture is shown below. MobileFaceNet consists of three types of block: Conv_block, the Depth_Wise module and the residual bottleneck. Residual bottlenecks with shortcuts (similar to ResNet) are the main building blocks. A Depth_Wise module first applies a 1x1 conv layer whose number of output channels is set by the expansion factor; a depth-wise conv and a point-wise conv then follow, which keeps the computation cost low. A residual bottleneck is a stack of Depth_Wise modules with shortcut connections.

<img src="imgs/8.png"  width="600" style="float: left;">
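
To put a rough number on the cost saving, here is a minimal pure-Python sketch that counts convolution weights (bias-free, ignoring BN and PReLU) for a depth_wise module with the channel sizes of `conv_23` in the printout below (64 -> 128 -> 64), compared with one plain 3x3 conv from 64 to 64 channels. The helpers `conv_params` and `depth_wise_params` are hypothetical and not part of the MFN code.

```python
def conv_params(in_ch, out_ch, k, groups=1):
    """Weight count of a bias-free Conv2d: out_ch * (in_ch // groups) * k * k."""
    return out_ch * (in_ch // groups) * k * k


def depth_wise_params(in_ch, out_ch, expand_ch, k=3):
    """Hypothetical helper: 1x1 expand -> kxk depth-wise -> 1x1 linear projection."""
    expand = conv_params(in_ch, expand_ch, 1)                          # 1x1 bottleneck
    depthwise = conv_params(expand_ch, expand_ch, k, groups=expand_ch)  # one kxk filter per channel
    pointwise = conv_params(expand_ch, out_ch, 1)                      # 1x1 linear conv
    return expand + depthwise + pointwise


standard = conv_params(64, 64, 3)            # plain 3x3 conv, 64 -> 64
separable = depth_wise_params(64, 64, 128)   # conv_23-like module, 64 -> 128 -> 64
print(standard, separable)  # 36864 17536
```

Stride does not change the weight count, so the stride-2 depth-wise conv in `conv_23` is counted the same way.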

```python
import torch
from MFN.Base_Model.face_model import *

model = MobileFaceNet(512)
model.load_state_dict(torch.load('MFN/Base_Model/MobileFace_Net', map_location=lambda storage, loc: storage))
model
```

```text
MobileFaceNet(
  (conv1): Conv_block(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=64)
  )
  (conv2_dw): Conv_block(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=64)
  )
  (conv_23): Depth_Wise(
    (conv): Conv_block(
      (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=128)
    )
    (conv_dw): Conv_block(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_st
```

## Pruning the depth_wise module

Let's start with pruning the depth_wise module. It consists of three convolutional layers. The first 1x1 conv layer acts as the bottleneck. The second, depth-wise conv has as many filters as "in_channels" (groups equals in_channels), so each input channel is convolved with its own filter; the spatial size of its output depends on its stride. The third layer, a 1x1 linear conv also called the point-wise conv, produces the desired number of output channels. Within this module we can **only** prune the first 1x1 conv layer. Because the number of output channels of the 1x1 conv must equal the number of depth-wise filters, the depth-wise conv has to drop the same filter indexes. Since the depth-wise output then has fewer channels, the "in_channels" of the third, linear conv must be reconstructed as well. The diagram below illustrates this depth_wise module pruning approach.

<img src="imgs/6.png"  width="800" style="float: left;">
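
The bookkeeping in the diagram can be illustrated with a toy, framework-free sketch in which every kernel is collapsed to a single number, so only the channel dimensions remain. `drop` and `prune_depth_wise_toy` are hypothetical helpers, not the `MFN.utils.prune_MFN` functions; in the real module the BN and PReLU parameters of the first two layers are sliced with the same indexes.

```python
def drop(seq, filter_index):
    """Return seq without the positions listed in filter_index."""
    removed = set(filter_index)
    return [x for i, x in enumerate(seq) if i not in removed]


def prune_depth_wise_toy(expand_w, dw_w, project_w, filter_index):
    """Toy Depth_Wise module with every kernel collapsed to a single number.

    expand_w:  [out][in] weights of the first 1x1 conv (one row per filter)
    dw_w:      [ch]      one kernel per channel of the depth-wise conv
    project_w: [out][in] weights of the 1x1 linear (point-wise) conv
    """
    expand_w = drop(expand_w, filter_index)                     # fewer out_channels
    dw_w = drop(dw_w, filter_index)                             # same indexes (groups == channels)
    project_w = [drop(row, filter_index) for row in project_w]  # fewer in_channels
    return expand_w, dw_w, project_w


expand_w = [[1, 0], [0, 1], [1, 1], [2, 2]]  # 2 -> 4 channels (expansion)
dw_w = [10, 20, 30, 40]                      # one kernel per expanded channel
project_w = [[1, 2, 3, 4], [5, 6, 7, 8]]     # 4 -> 2 channels (linear projection)

e, d, p = prune_depth_wise_toy(expand_w, dw_w, project_w, (1, 3))
print(e, d, p)  # [[1, 0], [1, 1]] [10, 30] [[1, 3], [5, 7]]
```

The module's own input and output channel counts (2 and 2) are unchanged; only the internal expanded width shrinks from 4 to 2.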

## Pruning the residual bottleneck

A residual bottleneck is a stack of depth-wise modules with shortcut connections. Because the shortcuts use element-wise addition, every depth-wise module in the stack must produce an output of the same shape. The pruning target is therefore the output feature map of the whole residual bottleneck. Once we decide which **linear conv** channels to prune, based on the ranking of the whole bottleneck output, the linear conv of every depth-wise module in the stack, as well as the upstream conv that feeds the first shortcut, must have the same "out_channels" pruned so that the shapes still match. The reasoning is that if a channel of the summed feature map (the sum of all depth-wise module outputs) is unimportant to the loss, each term contributing to that channel is treated as unimportant too. Take a moment to picture it, and don't forget to shrink the "in_channels" of the conv layers that consume these outputs (the first 1x1 conv of each following depth-wise module)!

<img src="imgs/7.png"  width="800" style="float: left;">
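
A toy check of the element-wise argument, assuming one value per channel at a single spatial position: removing a channel from the bottleneck output stays consistent only if the same channel is removed from every addend of the shortcut chain. The helper names and numbers are made up for illustration.

```python
def drop(seq, filter_index):
    """Return seq without the positions listed in filter_index."""
    removed = set(filter_index)
    return [x for i, x in enumerate(seq) if i not in removed]


def add(*vectors):
    """Element-wise (shortcut) addition; all inputs need the same channel count."""
    if len({len(v) for v in vectors}) != 1:
        raise ValueError("shortcut addition needs equal channel counts")
    return [sum(values) for values in zip(*vectors)]


# One value per channel (4 channels) at a single spatial position.
x0 = [1, 2, 3, 4]    # output of the upstream linear conv (feeds the first shortcut)
f1 = [1, 0, -1, 2]   # linear conv output of residual block 1
f2 = [1, 1, 1, 1]    # linear conv output of residual block 2

y1 = add(x0, f1)     # [2, 2, 2, 6]
y2 = add(y1, f2)     # [3, 3, 3, 7] -> the bottleneck output that gets ranked

# Pruning channel 2 of y2 means removing channel 2 from every addend.
pruned = add(drop(x0, (2,)), drop(f1, (2,)), drop(f2, (2,)))
print(pruned == drop(y2, (2,)))  # True

# Pruning only one addend breaks the element-wise addition.
try:
    add(drop(x0, (2,)), f1)
except ValueError as err:
    print(err)  # shortcut addition needs equal channel counts
```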

All right! Hopefully the basics of MobileFaceNet pruning are clear now. Let's look at the code in detail.

First, let's rearrange the model into an ordered, flat module dict to make indexing and the subsequent pruning easier.

```python
# Flatten the model into {index: block}, expanding Depth_Wise and Residual modules
index = 0
modules = {}
for names, module in list(model._modules.items()):
    if isinstance(module, Depth_Wise):
        for _, module_sub in list(module._modules.items()):
            modules[index] = module_sub
            index += 1

    elif isinstance(module, Residual):
        for i in range(len(module.model)):
            for _, model_sub_sub in list(module.model[i]._modules.items()):
                modules[index] = model_sub_sub
                index += 1
    else:
        modules[index] = module
        index += 1
modules
```

```text
{0: Conv_block(
   (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
   (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (prelu): PReLU(num_parameters=64)
 ), 1: Conv_block(
   (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
   (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (prelu): PReLU(num_parameters=64)
 ), 2: Conv_block(
   (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
   (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (prelu): PReLU(num_parameters=128)
 ), 3: Conv_block(
   (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
   (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (prelu): PReLU(num_parameters=128)
 ), 4: Linear_block(
   (conv): Conv2d(128, 64, kernel
```
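
As a sanity check on the flat indexing, the sketch below replays the regrouping loop on a hypothetical layout table and derives the two index tuples that FilterPrunner hard-codes further down (Res_layers and prunning_layers). The block names after `conv_23` and the kind tags are assumptions, since the printout above is truncated; the residual block counts 4, 6 and 2 are the `num_blocks` values used in `prune_MFN`. The layers after index 47 (the final depth-wise conv, flatten and fully connected layers) are left out because they are not pruning targets themselves.

```python
# Hypothetical layout: (block name, kind, number of Depth_Wise modules inside).
LAYOUT = [
    ("conv1", "plain", 1),
    ("conv2_dw", "depthwise", 1),
    ("conv_23", "depth_wise", 1),
    ("conv_3", "residual", 4),
    ("conv_34", "depth_wise", 1),
    ("conv_4", "residual", 6),
    ("conv_45", "depth_wise", 1),
    ("conv_5", "residual", 2),
    ("conv_6_sep", "plain", 1),
]


def flat_indices(layout):
    """Replay the regrouping loop and classify each flat index."""
    index = 0
    res_layers, prunable, bottleneck_outputs = [], [], []
    for _name, kind, count in layout:
        if kind == "plain":          # regular Conv_block: prunable
            prunable.append(index)
            index += 1
        elif kind == "depthwise":    # stand-alone depth-wise Conv_block: not prunable
            index += 1
        else:                        # each Depth_Wise module: 1x1 conv, depth-wise conv, linear conv
            for _ in range(count):
                prunable.append(index)            # first 1x1 conv
                if kind == "residual":
                    res_layers.append(index + 2)  # linear conv output + shortcut
                index += 3
            if kind == "residual":
                bottleneck_outputs.append(index - 1)
    return res_layers, prunable + bottleneck_outputs


res_layers, prunable = flat_indices(LAYOUT)
print(res_layers)  # [7, 10, 13, 16, 22, 25, 28, 31, 34, 37, 43, 46]
print(prunable)    # [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 16, 37, 46]
```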

Next, import the pruning helpers, which return newly constructed layers without the to-be-pruned filters. The approach is explained in the "Prunning MTCNN Tutorial".

```python
from MFN.utils.prune_MFN import prune_Conv2d, prune_BN, prune_PReLu, prune_linear
```

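Now create a pruning function for MobileFaceNet. "prune_MFN" accepts a layer index followed by the filter indexes to prune (passed as separate positional arguments, i.e. `*filter_index`) and returns the pruned model together with the regrouped module dict. It distinguishes two cases: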

1. If layer_index points to a regular (non-depth-wise) Conv_block or to the first 1x1 conv layer of a depth_wise module, a while loop walks forward until the next layer is no longer a depth-wise conv. That next layer is either a conv layer or, for the last conv block (layer_index = 47), the fully connected layer; both cases need to be handled.
2. If layer_index is one of the layers giving the output of a whole bottleneck block (16, 37, 46), the linear conv of every depth-wise module in that block must be pruned, together with the upstream "out_channels" and, of course, the downstream conv layers' "in_channels".
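
The index arithmetic of case 2 can be sketched without the model. The sketch mirrors the `layer_index - 3*i` loop in `prune_MFN` below and the 4, 6 and 2 residual blocks hard-coded there; `residual_prune_targets` is a hypothetical helper.

```python
# Residual blocks behind each bottleneck output, as hard-coded in prune_MFN.
BOTTLENECKS = {16: 4, 37: 6, 46: 2}


def residual_prune_targets(layer_index):
    """Return (layers losing out_channels, layers losing in_channels)."""
    num_blocks = BOTTLENECKS[layer_index]
    # i = 0 .. num_blocks-1 are the linear convs inside the stack;
    # i = num_blocks is the linear conv of the preceding Depth_Wise module.
    out_layers = [layer_index - 3 * i for i in range(num_blocks + 1)]
    # each of those outputs is consumed by the first 1x1 conv right after it
    in_layers = [layer + 1 for layer in out_layers]
    return out_layers, in_layers


for layer in BOTTLENECKS:
    print(layer, residual_prune_targets(layer))
# 16 ([16, 13, 10, 7, 4], [17, 14, 11, 8, 5])
# 37 ([37, 34, 31, 28, 25, 22, 19], [38, 35, 32, 29, 26, 23, 20])
# 46 ([46, 43, 40], [47, 44, 41])
```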

```python
def prune_MFN(model, layer_index, *filter_index, use_cuda=True):

    # Regroup the model into a flat {index: block} dict (same order as above)
    index = 0
    modules = {}
    for names, module in list(model._modules.items()):
        if isinstance(module, Depth_Wise):
            for _, module_sub in list(module._modules.items()):
                modules[index] = module_sub
                index += 1

        elif isinstance(module, Residual):
            for i in range(len(module.model)):
                for _, model_sub_sub in list(module.model[i]._modules.items()):
                    modules[index] = model_sub_sub
                    index += 1
        else:
            modules[index] = module
            index += 1

    # filter_index is a tuple collected by *filter_index, so test for emptiness
    if layer_index is None or len(filter_index) == 0:
        return model, modules

    # Case 1: a regular Conv_block or the first 1x1 conv of a depth_wise module
    if isinstance(modules[layer_index], Conv_block):
        if modules[layer_index].conv.groups != modules[layer_index].conv.in_channels:

            conv = modules[layer_index].conv
            bn = modules[layer_index].bn
            prelu = modules[layer_index].prelu

            # drop the selected filters (out_channels) of this block
            modules[layer_index].conv = prune_Conv2d(conv, filter_index, Next=False, use_cuda=use_cuda)
            modules[layer_index].bn = prune_BN(bn, filter_index, use_cuda=use_cuda)
            modules[layer_index].prelu = prune_PReLu(prelu, filter_index, use_cuda=use_cuda)

            # walk forward with a local index so layer_index itself stays unchanged
            idx = layer_index
            next_conv = modules[idx+1].conv
            modules[idx+1].conv = prune_Conv2d(next_conv, filter_index, Next=True, use_cuda=use_cuda)
            # a depth-wise conv passes the same channels through, so keep going
            while modules[idx+1].conv.groups != 1:
                bn = modules[idx+1].bn
                modules[idx+1].bn = prune_BN(bn, filter_index, use_cuda=use_cuda)
                if isinstance(modules[idx+1], Conv_block):
                    prelu = modules[idx+1].prelu
                    modules[idx+1].prelu = prune_PReLu(prelu, filter_index, use_cuda=use_cuda)
                idx += 1
                if isinstance(modules[idx+2], Linear):
                    # the last depth-wise conv is followed by Flatten and the fully connected layer
                    next_linear = modules[idx+2]
                    modules[idx+2] = prune_linear(next_linear, next_conv, filter_index, use_cuda=use_cuda)
                else:
                    next_conv = modules[idx+1].conv
                    modules[idx+1].conv = prune_Conv2d(next_conv, filter_index, Next=True, use_cuda=use_cuda)
                if isinstance(modules[idx+1], Flatten):
                    break

    # Case 2: the output of a whole bottleneck block -> number of residual blocks behind it
    num_blocks = {16: 4, 37: 6, 46: 2}.get(layer_index)
    if num_blocks is not None:
        # every linear conv in the stack plus the upstream one feeding the first shortcut
        for i in range(num_blocks + 1):
            conv = modules[layer_index - 3*i].conv
            bn = modules[layer_index - 3*i].bn

            modules[layer_index - 3*i].conv = prune_Conv2d(conv, filter_index, Next=False, use_cuda=use_cuda)
            modules[layer_index - 3*i].bn = prune_BN(bn, filter_index, use_cuda=use_cuda)

            # the first 1x1 conv right after each output loses the same in_channels
            next_conv = modules[layer_index + 1 - 3*i].conv
            modules[layer_index + 1 - 3*i].conv = prune_Conv2d(next_conv, filter_index, Next=True, use_cuda=use_cuda)

    # Write the (possibly replaced) blocks back into the model
    index = 0
    for names, module in list(model._modules.items()):
        if isinstance(module, Depth_Wise):
            for sub_name in list(module._modules.keys()):
                module._modules[sub_name] = modules[index]
                index += 1

        elif isinstance(module, Residual):
            for i in range(len(module.model)):
                for sub_name in list(module.model[i]._modules.keys()):
                    module.model[i]._modules[sub_name] = modules[index]
                    index += 1
        else:
            model._modules[names] = modules[index]
            index += 1

    return model, modules
```


Here is an example; feel free to play with layer_index and filter_index. Note, however, that not every layer index is valid. The layer indexes accepted by the "prune_MFN" function are:

1. a regular (non-depth-wise) Conv_block,
2. the first 1x1 conv layer in a depth_wise module,
3. the linear conv layer that gives the output of a whole bottleneck block.

```python
layer_index = 16
filter_index = (2, 4)

# keep the returned module dict in sync with the pruned model
model, modules = prune_MFN(model, layer_index, *filter_index, use_cuda=False)
model
```

```text
MobileFaceNet(
  (conv1): Conv_block(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=64)
  )
  (conv2_dw): Conv_block(
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=64)
  )
  (conv_23): Depth_Wise(
    (conv): Conv_block(
      (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=128)
    )
    (conv_dw): Conv_block(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_st
```

## FilterPrunner Class

With the regrouped module dict, the FilterPrunner class becomes quite straightforward. Its constructor takes the converted module dict instead of the model, and "forward" iterates over each item to capture the intermediate feature maps, adding the shortcut for the layers listed in Res_layers. prunning_layers lists all the valid layer indexes accepted by the "prune_MFN" function; for each of them a gradient hook is registered and the Taylor value is calculated. All other methods are the same as in the MTCNN pruning tutorial.
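
To make the criterion concrete, here is a framework-free sketch of what `compute_rank` and `normalize_ranks_per_layer` compute for a single batch: the per-channel mean of activation x gradient over the batch and spatial dimensions, then the absolute value divided by the layer's L2 norm. The nested-list inputs and the helper names `taylor_scores` and `normalize` are assumptions for illustration; the class itself works on torch tensors and accumulates scores over batches before normalizing.

```python
import math


def taylor_scores(activation, grad):
    """Per-channel mean of activation * grad; inputs are nested lists [N][C][H][W]."""
    n_channels = len(activation[0])
    scores = []
    for c in range(n_channels):
        products = [
            a * g
            for sample_a, sample_g in zip(activation, grad)    # batch
            for row_a, row_g in zip(sample_a[c], sample_g[c])  # height
            for a, g in zip(row_a, row_g)                      # width
        ]
        scores.append(sum(products) / len(products))
    return scores


def normalize(scores):
    """Absolute value, then divide by the layer's L2 norm."""
    v = [abs(s) for s in scores]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


activation = [[[[1.0, 2.0]], [[3.0, -1.0]]]]  # N=1, C=2, H=1, W=2
grad = [[[[0.5, 0.5]], [[1.0, 1.0]]]]

scores = taylor_scores(activation, grad)
print(scores)             # [0.75, 1.0]
print(normalize(scores))  # [0.6, 0.8]
```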

```python
import torch
from heapq import nsmallest
from operator import itemgetter


class FilterPrunner:
    def __init__(self, model, use_cuda=False):
        self.model = model  # the regrouped {index: block} dict, not the nn.Module
        self.reset()
        self.use_cuda = use_cuda

    def reset(self):
        self.filter_ranks = {}

    def forward(self, x):
        self.activations = []
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {}

        activation_index = 0
        Res_layers = (7, 10, 13, 16, 22, 25, 28, 31, 34, 37, 43, 46)  # res layers requiring shortcuts
        prunning_layers = (0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 16, 37, 46)  # the layers to be pruned
        outputs = {}
        for index, module in self.model.items():
            if isinstance(module, Linear_block) and index in Res_layers:
                x = module(x) + outputs[index-3]  # add the shortcut
            else:
                x = module(x)

            outputs[index] = x

            if index in prunning_layers:
                x.register_hook(self.compute_rank)
                self.activations.append(x)
                self.activation_to_layer[activation_index] = index  # flat layer index of this activation
                activation_index += 1

        return l2_norm(x)

    def compute_rank(self, grad):
        # hooks fire in reverse order of registration during backward
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]
        taylor = activation * grad

        # Get the average value for every filter,
        # across all the other dimensions
        taylor = taylor.mean(dim=(0, 2, 3)).data

        if activation_index not in self.filter_ranks:
            self.filter_ranks[activation_index] = \
                torch.FloatTensor(activation.size(1)).zero_()

            if self.use_cuda:
                self.filter_ranks[activation_index] = self.filter_ranks[activation_index].cuda()

        self.filter_ranks[activation_index] += taylor
        self.grad_index += 1

    def lowest_ranking_filters(self, num):
        data = []
        for i in sorted(self.filter_ranks.keys()):
            for j in range(self.filter_ranks[i].size(0)):
                data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j]))

        return nsmallest(num, data, itemgetter(2))

    def normalize_ranks_per_layer(self):
        for i in self.filter_ranks:
            v = torch.abs(self.filter_ranks[i]).cpu()
            v = v / torch.sqrt(torch.sum(v * v))  # L2-normalize within the layer
            self.filter_ranks[i] = v

    def get_prunning_plan(self, num_filters_to_prune):
        filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune)

        filters_to_prune_per_layer = {}
        for (l, f, _) in filters_to_prune:
            if l not in filters_to_prune_per_layer:
                filters_to_prune_per_layer[l] = []
            filters_to_prune_per_layer[l].append(f)

        for l in filters_to_prune_per_layer:
            filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l])

        return filters_to_prune_per_layer
```

Let's create an instance of FilterPrunner and demonstrate the pruning process. The Arcface head has to be loaded to calculate the loss.

```python
prunner = FilterPrunner(modules, use_cuda=False)

model.train()
margin = Arcface(embedding_size=512, classnum=85742, s=32., m=0.5)
checkpoint = torch.load("MFN/Base_Model/Iter_528000_margin.ckpt", map_location=lambda storage, loc: storage)
margin.load_state_dict(checkpoint['net_state_dict'])
criterion = torch.nn.CrossEntropyLoss()

prunner.reset()
```

Create a fake input image batch and label batch to calculate the loss and back-propagate the gradients.

```python
img = torch.randn(3, 3, 112, 112)                   # fake batch: 3 RGB images of 112x112
label = torch.tensor([1, 5, 10], dtype=torch.long)

model.zero_grad()

with torch.set_grad_enabled(True):
    raw_logits = prunner.forward(img)
    output = margin(raw_logits, label)
    loss = criterion(output, label)
    loss.backward()
```

Normalize the ranks and build the pruning plan: a dict whose keys are layer indexes and whose values are the filter indexes to prune.

```python
import numpy as np
from operator import itemgetter
from heapq import nsmallest
prunner.normalize_ranks_per_layer()
filters_to_prune = prunner.get_prunning_plan(20)
filters_to_prune
```

```text
{0: [0, 1, 10, 33, 38, 42, 55],
 26: [189],
 20: [45, 76],
 35: [45, 118],
 38: [148],
 29: [63, 166, 220],
 32: [8],
 23: [133, 201],
 44: [133]}
```
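
The plan above is produced by picking the globally lowest normalized scores and grouping them by layer. The sketch below reproduces that grouping with plain lists instead of tensors; `pruning_plan` and the toy ranks are hypothetical, and the sketch can be handy for unit-testing the grouping logic without a model.

```python
from heapq import nsmallest
from operator import itemgetter


def pruning_plan(filter_ranks, activation_to_layer, num):
    """Group the num lowest (layer, filter, score) triples by layer index."""
    data = [
        (activation_to_layer[i], j, score)
        for i in sorted(filter_ranks)
        for j, score in enumerate(filter_ranks[i])
    ]
    plan = {}
    for layer, f, _ in nsmallest(num, data, key=itemgetter(2)):
        plan.setdefault(layer, []).append(f)
    return {layer: sorted(fs) for layer, fs in plan.items()}


filter_ranks = {0: [0.9, 0.1, 0.5], 1: [0.2, 0.8]}  # activation index -> per-filter score
activation_to_layer = {0: 0, 1: 2}                  # activation index -> flat layer index

plan = pruning_plan(filter_ranks, activation_to_layer, 3)
print(plan)  # {0: [1, 2], 2: [0]}
```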