![logo](../../picture/license_header_logo.png)
**Copyright (c) 2020-2021 CertifAI Sdn. Bhd.**

This program is part of OSRFramework. You can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see http://www.gnu.org/licenses/.

Authored by: [Jacklyn Lim](mailto:jacklyn.lim@certifai.ai)

## Import Libraries

In [1]:
import copy
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from utils import download_model, download_dataset, load_model_state_dict, load_dataset, load_image, inspect_module, compare_performance

## Define Model Architecture

In [2]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Note that the number of input features of this layer depends on your input image size
        self.fc1 = nn.Linear(18496, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
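The comment on ``fc1`` says that its input size depends on the image size. Below is a minimal pure-Python sketch of that dependency. It follows the layer sequence in ``Net.forward``: 5×5 convolutions without padding, and 2×2 max pooling with stride 2. The helper names are hypothetical. The resize applied by ``utils.load_image`` is not shown here, so 148×148 is only an assumed input size that happens to be consistent with ``18496``.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(height, width, channels=16):
    """Number of features reaching fc1 for a height x width input to Net.

    Mirrors Net.forward: conv1 (5x5) -> pool (2x2, stride 2) -> conv2 (5x5) -> pool.
    `channels` is conv2's number of output channels.
    """
    for kernel, stride in [(5, 1), (2, 2), (5, 1), (2, 2)]:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return channels * height * width


print(flattened_features(148, 148))  # 18496 = 16 * 34 * 34
print(flattened_features(32, 32))    # 400 = 16 * 5 * 5
```

If your images are resized to a different size, ``nn.Linear(flattened_features(h, w), 120)`` would be the matching first fully connected layer.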

## Download Model and Dataset

In [3]:
# model download
MODEL_DOWNLOAD_PATH = 'https://s3.eu-central-1.wasabisys.com/certifai/deployment-training-labs/models/fruit_classifier_state_dict.pt'
MODEL_STATE_DICT_PATH = '../../resources/model/'
MODEL_FILENAME = 'fruits_image_classification.zip'
download_model(MODEL_DOWNLOAD_PATH, MODEL_STATE_DICT_PATH, MODEL_FILENAME)

# data download
DATA_DOWNLOAD_PATH = "https://s3.eu-central-1.wasabisys.com/certifai/deployment-training-labs/fruits_image_classification-20210604T123547Z-001.zip"
DATA_SAVE_PATH = "../../resources/data/"
DATA_ZIP_FILENAME = "fruits_image_classification.zip"
download_dataset(DATA_DOWNLOAD_PATH, DATA_SAVE_PATH, DATA_ZIP_FILENAME)

model already exists, skipping download
data already exists, skipping download


## Load Model

In [4]:
model = Net()
model = load_model_state_dict(model, MODEL_STATE_DICT_PATH + MODEL_FILENAME)

original_model = copy.deepcopy(model)
single_pruned_model = copy.deepcopy(model)

## Local Pruning

### Prune Model
Pruning removes ``weight`` from the module's parameters and replaces it with a new parameter called ``weight_orig``, which stores the unpruned version of the tensor. The ``bias`` is not pruned, so it remains intact.

The pruning mask generated by the chosen pruning technique (here, L1 unstructured pruning, which zeroes the weights with the smallest absolute values) is saved as a module buffer named ``weight_mask`` (i.e. ``"_mask"`` appended to the initial parameter ``name``). This is why the module's buffer list is empty before pruning but contains ``weight_mask`` afterwards.
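The same change can be observed on a tiny layer without the notebook's ``utils`` helpers. This is a minimal sketch that uses a stand-in ``nn.Linear(4, 3)`` (12 weights) instead of ``conv1``. It prints parameter and buffer names before and after ``prune.l1_unstructured`` with ``amount=0.3``, which masks ``round(0.3 * 12) = 4`` weights.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # small stand-in for conv1: 12 weights

params_before = [name for name, _ in layer.named_parameters()]
buffers_before = [name for name, _ in layer.named_buffers()]

# Prune 30% of the weights with the smallest absolute value (L1 norm).
prune.l1_unstructured(layer, name="weight", amount=0.3)

params_after = [name for name, _ in layer.named_parameters()]
buffers_after = [name for name, _ in layer.named_buffers()]

print(params_before, buffers_before)        # ['weight', 'bias'] []
print(params_after, buffers_after)          # ['bias', 'weight_orig'] ['weight_mask']
print(int((layer.weight_mask == 0).sum()))  # 4
```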

In [5]:
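def single_prune(model, module):
    # Prune 30% of `module.weight` entries with the smallest L1 norm (absolute value).
    # `model` is not used here; only the given module is pruned, in place.
    prune.l1_unstructured(module, name="weight", amount=0.3)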

In [6]:
# inspect module before pruning
print("\033[1mModule Before Pruning \033[0m")
module = single_pruned_model.conv1
inspect_module(module)

# prune the module (conv1 in this example)
single_prune(single_pruned_model, module)

# inspect module after pruning
print("\n\033[1mModule After Pruning \033[0m")
inspect_module(module)
    

Module Before Pruning

Module Parameters:
[('weight', Parameter containing:
tensor([[[[ 7.8413e-02, -6.1813e-02, -8.1539e-03,  1.9616e-02,  3.5149e-02],
          [-1.0102e-01,  1.1059e-02,  6.4063e-02,  6.7392e-02,  5.8826e-02],
          [ 2.3068e-02, -1.0635e-01,  7.8187e-02, -8.9222e-02, -6.6019e-02],
          [ 4.4278e-03,  3.7632e-02,  6.8275e-02, -1.5092e-02, -5.8075e-02],
          [-7.7381e-03,  3.9115e-02, -6.3891e-02, -9.9119e-02,  5.6131e-02]],

         [[-5.3576e-02, -4.2260e-02,  1.1709e-01, -5.5879e-02, -7.3201e-02],
          [ 1.1009e-01, -3.2467e-02,  2.8705e-02,  1.3480e-01, -2.5989e-02],
          [ 1.2869e-01,  1.1275e-01,  1.0845e-01,  5.3049e-02, -3.2747e-02],
          [ 9.2628e-02, -5.9429e-02, -4.8797e-02, -3.7474e-02,  6.8518e-02],
          [ 3.5430e-02,  9.2111e-02,  9.4129e-02,  2.4983e-02,  8.5755e-02]],

         [[-1.0724e-01, -1.3468e-01, -1.1950e-01,  3.8523e-02,  8.1936e-02],
          [-5.4905e-02,  4.0974e-02, -1.0695e-01, -2.7370e-02, -

The pruning techniques implemented in ``torch.nn.utils.prune`` <b>compute the pruned version of the weight (by multiplying the original parameter by the mask)</b> and <b>store it in the attribute ``weight``</b>. Note that ``weight`` is no longer a parameter of the module; it is now simply an attribute.

Notice the pruned weights:

In [7]:
module.weight

tensor([[[[ 0.0784, -0.0618, -0.0000,  0.0000,  0.0000],
          [-0.1010,  0.0000,  0.0641,  0.0674,  0.0588],
          [ 0.0000, -0.1063,  0.0782, -0.0892, -0.0660],
          [ 0.0000,  0.0376,  0.0683, -0.0000, -0.0581],
          [-0.0000,  0.0391, -0.0639, -0.0991,  0.0561]],

         [[-0.0536, -0.0423,  0.1171, -0.0559, -0.0732],
          [ 0.1101, -0.0000,  0.0000,  0.1348, -0.0000],
          [ 0.1287,  0.1128,  0.1085,  0.0530, -0.0000],
          [ 0.0926, -0.0594, -0.0488, -0.0000,  0.0685],
          [ 0.0000,  0.0921,  0.0941,  0.0000,  0.0858]],

         [[-0.1072, -0.1347, -0.1195,  0.0385,  0.0819],
          [-0.0549,  0.0410, -0.1070, -0.0000, -0.0000],
          [-0.0564, -0.0792,  0.0000, -0.0000, -0.0478],
          [ 0.0567, -0.1115,  0.0512,  0.0782, -0.0959],
          [-0.1335, -0.0000, -0.1121,  0.0000, -0.1167]]],


        [[[-0.0917,  0.0000, -0.0000, -0.1037,  0.0437],
          [ 0.0394,  0.0000,  0.0740, -0.0477,  0.1042],
          [-0.0893,  0.
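The relationship between the three tensors can be checked directly. This minimal sketch uses a stand-in ``nn.Linear(4, 3)`` layer rather than ``conv1``. It confirms that ``weight`` equals ``weight_orig * weight_mask`` element-wise and that ``weight`` is a plain tensor attribute rather than an ``nn.Parameter``.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.3)

# `weight` is the original values multiplied by the 0/1 mask ...
recomputed = layer.weight_orig * layer.weight_mask
print(torch.equal(layer.weight, recomputed))       # True
# ... and it is a plain tensor attribute, not a registered parameter.
print(isinstance(layer.weight, nn.Parameter))      # False
print("weight" in dict(layer.named_parameters()))  # False
```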

Lastly, pruning is applied before each forward pass using PyTorch's
``forward_pre_hooks``.

Specifically, when a ``module`` is pruned, as we have done here, it acquires a <b>``forward_pre_hook`` for each of its parameters that gets pruned</b>. Before every forward pass, each hook recomputes the pruned parameter from its ``_orig`` tensor and its ``_mask`` buffer, so the mask stays applied even if the original values change (for example, during training).

In this case, since we have so far only pruned the original parameter named ``weight``, only one hook is present.

In [8]:
for hook in module._forward_pre_hooks.values():
    print(hook)

<torch.nn.utils.prune.L1Unstructured object at 0x000001CF63913340>
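The effect of the hook can be seen by changing ``weight_orig`` directly. This minimal sketch uses a stand-in ``nn.Linear(4, 3)`` layer. It overwrites every original value with ``1.0``, the way an optimizer step might update the values, and then runs a forward pass. After that pass, ``weight`` equals the mask, so the pruned positions are zero again. ``_forward_pre_hooks`` is a private attribute, used here only for inspection, as in the cell above.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.3)
print(len(layer._forward_pre_hooks))  # 1: one hook for the pruned "weight"

# Overwrite the unpruned values, e.g. as an optimizer update would.
with torch.no_grad():
    layer.weight_orig.fill_(1.0)

# The forward pre-hook recomputes weight = weight_orig * weight_mask,
# so every kept entry becomes 1.0 and every pruned entry stays 0.
layer(torch.randn(2, 4))
print(torch.equal(layer.weight, layer.weight_mask))  # True
```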


### Serializing a Pruned Model

All relevant tensors, including the ``weight_mask`` buffer and the original parameters used to compute the pruned tensors, are stored in the model's ``state_dict`` and can therefore be easily serialized and saved, if needed. Note that the pruned layer now contributes ``conv1.weight_orig`` and ``conv1.weight_mask`` instead of ``conv1.weight``.

In [9]:
print(single_pruned_model.state_dict().keys())

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
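Because the pruned keys differ from those of an unpruned layer, a saved pruned ``state_dict`` cannot be loaded directly into a plain layer. This minimal sketch uses stand-in ``nn.Linear(4, 3)`` layers and an in-memory buffer instead of a file. It gives the target layer the same re-parametrization with ``prune.identity``, which adds an all-ones mask, before loading the state. This is one possible approach, not the notebook's own method.

```python
import io

import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
source = nn.Linear(4, 3)
prune.l1_unstructured(source, name="weight", amount=0.3)
print(list(source.state_dict().keys()))  # ['bias', 'weight_orig', 'weight_mask']

# Serialize to memory (saving to a file path works the same way).
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer)

# The target needs matching keys, so re-parametrize it before loading.
target = nn.Linear(4, 3)
prune.identity(target, "weight")
target.load_state_dict(state)

# `weight` is recomputed from the loaded tensors on the next forward pass.
target(torch.zeros(1, 4))
print(torch.equal(target.weight, source.weight))  # True
```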


### Remove Pruning Re-parametrization 

To make the pruning permanent (i.e. store the zeroed-out values directly in ``weight``), we need to remove the re-parametrization in terms of ``weight_orig`` and ``weight_mask`` and remove the ``forward_pre_hook``. We can do this with the ``remove`` function from ``torch.nn.utils.prune``.

This reassigns the pruned version of ``weight`` to the module as a regular parameter and deletes ``weight_orig``, ``weight_mask`` and the hook.
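The following minimal sketch shows what ``prune.remove`` leaves behind. It uses a stand-in ``nn.Linear(4, 3)`` layer instead of ``conv1``. After the call, ``weight`` is a regular parameter again, the mask buffer and hook are gone, and the zeros are baked into the parameter values.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.3)
pruned_values = layer.weight.detach().clone()

prune.remove(layer, "weight")

params = [name for name, _ in layer.named_parameters()]
buffers = [name for name, _ in layer.named_buffers()]
print(params, buffers)                         # ['bias', 'weight'] []
print(len(layer._forward_pre_hooks))           # 0
print(isinstance(layer.weight, nn.Parameter))  # True
print(torch.equal(layer.weight.detach(), pruned_values))  # True
```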

In [10]:
# inspect module before re-parametrization
print("\n\033[1mModule Before Re-parametrization \033[0m")
inspect_module(module)

# remove the re-parametrization (makes the pruning permanent)
prune.remove(module, 'weight')

# inspect module after re-parametrization
print("\n\033[1mModule After Re-parametrization \033[0m")
inspect_module(module)


Module Before Re-parametrization

Module Parameters:
[('bias', Parameter containing:
tensor([-0.0594, -0.0101, -0.5172, -0.0776, -0.1052, -0.1574],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 7.8413e-02, -6.1813e-02, -8.1539e-03,  1.9616e-02,  3.5149e-02],
          [-1.0102e-01,  1.1059e-02,  6.4063e-02,  6.7392e-02,  5.8826e-02],
          [ 2.3068e-02, -1.0635e-01,  7.8187e-02, -8.9222e-02, -6.6019e-02],
          [ 4.4278e-03,  3.7632e-02,  6.8275e-02, -1.5092e-02, -5.8075e-02],
          [-7.7381e-03,  3.9115e-02, -6.3891e-02, -9.9119e-02,  5.6131e-02]],

         [[-5.3576e-02, -4.2260e-02,  1.1709e-01, -5.5879e-02, -7.3201e-02],
          [ 1.1009e-01, -3.2467e-02,  2.8705e-02,  1.3480e-01, -2.5989e-02],
          [ 1.2869e-01,  1.1275e-01,  1.0845e-01,  5.3049e-02, -3.2747e-02],
          [ 9.2628e-02, -5.9429e-02, -4.8797e-02, -3.7474e-02,  6.8518e-02],
          [ 3.5430e-02,  9.2111e-02,  9.4129e-02,  2.4983e-02,  8.5755e-02]],

 

### Compare Model Performance

In [11]:
INFERENCE_IMAGE_PATH = "../../resources/data/fruits_image_classification/test/apple/image1.jpg"
TEST_DATASET_ROOTDIR = "../../resources/data/fruits_image_classification/test"

# load image
inference_image = load_image(INFERENCE_IMAGE_PATH)

# load test dataset
test_dataloader = load_dataset(TEST_DATASET_ROOTDIR)

compare_performance(original_model, single_pruned_model, "original_model", "single_pruned_model", inference_image, test_dataloader)

Comparing size of models
model:  original_model  	 Size (KB): 8935.103
model:  single_pruned_model  	 Size (KB): 8935.103
1.00 times smaller

Comparing latency of models
model:  original_model  	 prediction time: 0.003998756408691406s
model:  single_pruned_model  	 prediction time: 0.0030012130737304688s

Comparing accuracy of models
model:  original_model  	 Test Accuracy: 0.74
model:  single_pruned_model  	 Test Accuracy: 0.74


## Global Pruning

So far, we have only looked at what is usually referred to as "local" pruning, i.e. the practice of **pruning tensors in a model one by one**, by comparing the statistics (weight magnitude, activation, gradient, etc.) of each entry exclusively to the other entries in that tensor.


However, a common and perhaps more powerful technique is to **prune the model all at once, by removing (for example) the lowest 20% of connections across the whole model**, instead of removing the lowest 20% of connections in each layer.
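The difference between the two rankings shows up clearly with two hand-built toy layers, one holding only small weights and one holding only large weights. In the minimal sketch below, local pruning removes half of each layer. Global pruning ranks all eight weights together, so with the same ``amount=0.5`` it removes the whole small layer and leaves the large one untouched.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def make_layers():
    """Two toy layers: `small` holds small weights, `large` holds large ones."""
    small, large = nn.Linear(2, 2), nn.Linear(2, 2)
    with torch.no_grad():
        small.weight.copy_(torch.tensor([[0.1, 0.2], [0.3, 0.4]]))
        large.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
    return small, large


# Local: remove the lowest 50% of weights inside each layer separately.
small, large = make_layers()
for layer in (small, large):
    prune.l1_unstructured(layer, name="weight", amount=0.5)
local_zeros = [int((layer.weight == 0).sum()) for layer in (small, large)]
print(local_zeros)  # [2, 2]

# Global: remove the lowest 50% of all 8 weights, ranked together.
small, large = make_layers()
prune.global_unstructured(
    [(small, "weight"), (large, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)
global_zeros = [int((layer.weight == 0).sum()) for layer in (small, large)]
print(global_zeros)  # [4, 0]
```

The same effect can make a layer with relatively large weights, such as ``conv1``, lose few or no weights under global pruning, even though the model as a whole reaches the requested sparsity.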

### Global Pruning 

In [12]:
global_pruned_model = copy.deepcopy(model)

parameters_to_prune = (
    (global_pruned_model.conv1, 'weight'),
    (global_pruned_model.conv2, 'weight'),
    (global_pruned_model.fc1, 'weight'),
    (global_pruned_model.fc2, 'weight'),
    (global_pruned_model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured, 
    amount=0.3,
)

### Re-parametrization

In [13]:
# remove the re-parametrization to make the pruning permanent
for module, parameter_name in parameters_to_prune:
    prune.remove(module, parameter_name)

### Compare Model Performance

In [14]:
compare_performance(original_model, global_pruned_model, "original_model", "global_pruned_model", inference_image, test_dataloader)

Comparing size of models
model:  original_model  	 Size (KB): 8935.103
model:  global_pruned_model  	 Size (KB): 8935.103
1.00 times smaller

Comparing latency of models
model:  original_model  	 prediction time: 0.003000497817993164s
model:  global_pruned_model  	 prediction time: 0.0020029544830322266s

Comparing accuracy of models
model:  original_model  	 Test Accuracy: 0.74
model:  global_pruned_model  	 Test Accuracy: 0.74


## Why is the pruned model performing the same as the unpruned model?

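PyTorch represents the pruned (sparse) model with the same dense tensors as the unpruned (dense) model, so the number of stored parameters is unchanged; only the values of the pruned parameters have been set to zero. This is why the model size is identical. The small latency differences above are more likely run-to-run variation than a real speedup: the original model itself measured 0.0040 s in one run and 0.0030 s in the other. The test accuracy is also unchanged, which suggests that the removed low-magnitude weights contributed little to the predictions.

However, the PyTorch community is actively working on support for converting sparse neural networks to use sparse tensors. Once this is supported, we would expect a pruned (sparse) network to differ in size and speed from its original dense network.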
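The point about dense storage can be checked by counting elements. This minimal sketch uses a toy two-layer model, not the notebook's ``Net``. It prunes 30% of the weights globally, makes the pruning permanent, and shows that the number of stored parameter elements stays the same while only the count of non-zero elements drops.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))


def count_params(m):
    """Total number of stored parameter elements (zeros included)."""
    return sum(p.numel() for p in m.parameters())


def count_nonzero_params(m):
    """Number of parameter elements that are not exactly zero."""
    return sum(int(torch.count_nonzero(p)) for p in m.parameters())


before = count_params(toy_model)
toy_params = [(toy_model[0], "weight"), (toy_model[2], "weight")]
prune.global_unstructured(toy_params, pruning_method=prune.L1Unstructured, amount=0.3)
for toy_module, param_name in toy_params:
    prune.remove(toy_module, param_name)  # bake the zeros into the dense weights

print(before, count_params(toy_model))  # 220 220: same dense storage
print(count_nonzero_params(toy_model))  # 160: 60 of the 200 weights are now zero
```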

In [15]:
# sanity check to see if pruning has been done successfully
original_model.conv1.weight  # unpruned model

Parameter containing:
tensor([[[[ 7.8413e-02, -6.1813e-02, -8.1539e-03,  1.9616e-02,  3.5149e-02],
          [-1.0102e-01,  1.1059e-02,  6.4063e-02,  6.7392e-02,  5.8826e-02],
          [ 2.3068e-02, -1.0635e-01,  7.8187e-02, -8.9222e-02, -6.6019e-02],
          [ 4.4278e-03,  3.7632e-02,  6.8275e-02, -1.5092e-02, -5.8075e-02],
          [-7.7381e-03,  3.9115e-02, -6.3891e-02, -9.9119e-02,  5.6131e-02]],

         [[-5.3576e-02, -4.2260e-02,  1.1709e-01, -5.5879e-02, -7.3201e-02],
          [ 1.1009e-01, -3.2467e-02,  2.8705e-02,  1.3480e-01, -2.5989e-02],
          [ 1.2869e-01,  1.1275e-01,  1.0845e-01,  5.3049e-02, -3.2747e-02],
          [ 9.2628e-02, -5.9429e-02, -4.8797e-02, -3.7474e-02,  6.8518e-02],
          [ 3.5430e-02,  9.2111e-02,  9.4129e-02,  2.4983e-02,  8.5755e-02]],

         [[-1.0724e-01, -1.3468e-01, -1.1950e-01,  3.8523e-02,  8.1936e-02],
          [-5.4905e-02,  4.0974e-02, -1.0695e-01, -2.7370e-02, -2.5292e-02],
          [-5.6430e-02, -7.9165e-02,  2.9678e-02, 

In [16]:
# sanity check to see if pruning has been done successfully
global_pruned_model.conv1.weight  # pruned model

Parameter containing:
tensor([[[[ 0.0784, -0.0618, -0.0082,  0.0196,  0.0351],
          [-0.1010,  0.0111,  0.0641,  0.0674,  0.0588],
          [ 0.0231, -0.1063,  0.0782, -0.0892, -0.0660],
          [ 0.0044,  0.0376,  0.0683, -0.0151, -0.0581],
          [-0.0077,  0.0391, -0.0639, -0.0991,  0.0561]],

         [[-0.0536, -0.0423,  0.1171, -0.0559, -0.0732],
          [ 0.1101, -0.0325,  0.0287,  0.1348, -0.0260],
          [ 0.1287,  0.1128,  0.1085,  0.0530, -0.0327],
          [ 0.0926, -0.0594, -0.0488, -0.0375,  0.0685],
          [ 0.0354,  0.0921,  0.0941,  0.0250,  0.0858]],

         [[-0.1072, -0.1347, -0.1195,  0.0385,  0.0819],
          [-0.0549,  0.0410, -0.1070, -0.0274, -0.0253],
          [-0.0564, -0.0792,  0.0297, -0.0350, -0.0478],
          [ 0.0567, -0.1115,  0.0512,  0.0782, -0.0959],
          [-0.1335, -0.0084, -0.1121,  0.0244, -0.1167]]],


        [[[-0.0917,  0.0297, -0.0311, -0.1037,  0.0437],
          [ 0.0394,  0.0142,  0.0740, -0.0477,  0.1042],
 

## Compare Sparsity

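$$\text{GlobalSparsity} = \frac{\sum_{m} \#\{\, w \in W_m : w = 0 \,\}}{\sum_{m} \operatorname{numel}(W_m)}$$

where $W_m$ is the weight tensor of module $m$, $\operatorname{numel}(W_m)$ is its number of elements, and the sums run over ``conv1``, ``conv2``, ``fc1``, ``fc2`` and ``fc3``. The function below multiplies this ratio by 100 to report a percentage.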


In [17]:
def calculate_global_sparsity(model):
    return 100. * float(
        torch.sum(model.conv1.weight == 0)
        + torch.sum(model.conv2.weight == 0)
        + torch.sum(model.fc1.weight == 0)
        + torch.sum(model.fc2.weight == 0)
        + torch.sum(model.fc3.weight == 0)
    ) / float(
        model.conv1.weight.nelement()
        + model.conv2.weight.nelement()
        + model.fc1.weight.nelement()
        + model.fc2.weight.nelement()
        + model.fc3.weight.nelement()
    )
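``calculate_global_sparsity`` hard-codes the five layers of ``Net``. Below is a minimal, more general sketch. The function name ``sparsity_report`` is hypothetical. It walks every ``nn.Conv2d`` and ``nn.Linear`` submodule and also reports per-layer sparsity, which shows how a global pruning budget was distributed across layers. It works whether or not ``prune.remove`` has been called, because ``weight`` exists in both cases. Here it is demonstrated on a toy model rather than the notebook's models.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def sparsity_report(net):
    """Per-layer and overall weight sparsity (%) for Conv2d and Linear layers."""
    per_layer, zeros, total = {}, 0, 0
    for name, submodule in net.named_modules():
        if isinstance(submodule, (nn.Conv2d, nn.Linear)):
            n_zero = int(torch.sum(submodule.weight == 0))
            n_all = submodule.weight.nelement()
            per_layer[name] = 100.0 * n_zero / n_all
            zeros, total = zeros + n_zero, total + n_all
    return per_layer, 100.0 * zeros / total


# Usage on a toy network (not the notebook's Net):
torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
prune.global_unstructured(
    [(toy[0], "weight"), (toy[2], "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)
per_layer, overall = sparsity_report(toy)
print(per_layer)  # split between layers "0" and "2" depends on the weights
print(overall)    # 30.0
```

Calling ``sparsity_report(global_pruned_model)`` would key the results by ``conv1``, ``conv2``, ``fc1``, ``fc2`` and ``fc3``.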

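Remember that we are measuring the sparsity of the model (the fraction of weights that are zero), not its density. The globally pruned model is 30% sparse, compared with 0% for the unpruned model. The locally pruned model is barely sparse at all, because only ``conv1`` was pruned and it holds a tiny fraction of the model's weights.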

In [18]:
print("Global sparsity for unpruned model: {:.2f}%".format(calculate_global_sparsity(original_model)))
print("Global sparsity for locally pruned model: {:.2f}%".format(calculate_global_sparsity(single_pruned_model)))
print("Global sparsity for globally pruned model: {:.2f}%".format(calculate_global_sparsity(global_pruned_model)))

Global sparsity for unpruned model: 0.00%
Global sparsity for locally pruned model: 0.01%
Global sparsity for globally pruned model: 30.00%
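The 0.01% figure for the locally pruned model follows from the layer shapes in ``Net``. The plain-Python arithmetic below assumes that no ``conv1`` weight was exactly zero before pruning. It counts the weights in each layer, applies the 30% cut to ``conv1`` only (rounded as ``prune.l1_unstructured`` does), and divides by the total.

```python
# Weight-tensor sizes of Net, read off its layer definitions.
weight_counts = {
    "conv1": 6 * 3 * 5 * 5,   # 450
    "conv2": 16 * 6 * 5 * 5,  # 2,400
    "fc1": 120 * 18496,       # 2,219,520
    "fc2": 84 * 120,          # 10,080
    "fc3": 3 * 84,            # 252
}
total_weights = sum(weight_counts.values())

# single_prune removed 30% of conv1's weights and nothing else.
pruned_conv1 = round(0.3 * weight_counts["conv1"])

print(total_weights)                                   # 2232702
print(pruned_conv1)                                    # 135
print(f"{100.0 * pruned_conv1 / total_weights:.2f}%")  # 0.01%
print(f"{100.0 * weight_counts['fc1'] / total_weights:.1f}%")  # 99.4%
```

Since ``fc1`` holds about 99.4% of all weights, any pruning that leaves ``fc1`` untouched barely moves the global sparsity.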


## Additional Notes

As of June 2021, PyTorch does not yet support converting sparse neural networks to use sparse tensors. As mentioned previously, PyTorch represents the pruned (sparse) model in the same dense architecture as the unpruned model, so the number of parameters stays the same; only the values of the pruned parameters have been zeroed out.

Hence, pruning alone does not reduce the model size or speed up inference here.

However, combining model pruning with model quantization has shown promising results. The [TensorFlow](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras) pruning guide reports a
- 3x smaller model from pruning
- 10x smaller model from pruning + quantization

while maintaining accuracy.

## References

- [Pruning for Neural Networks by Lei Mao](https://leimao.github.io/article/Neural-Networks-Pruning/)
- [PyTorch Pruning by Lei Mao](https://leimao.github.io/blog/PyTorch-Pruning/)
- [Pruning Tutorial by PyTorch](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html)