<a href="https://colab.research.google.com/github/anujdutt9/PyTorch-DeepLearning/blob/master/ML_Model_Pruning_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ML Model Pruning using PyTorch

Ref.: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#extending-torch-nn-utils-prune-with-custom-pruning-functions

In [1]:
# Import Dependencies
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Check PyTorch Version
torch.__version__

'1.5.0+cu101'

In [3]:
# Create a simple LeNet model architecture
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
# Set device to train the model on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Model Initialization
model = LeNet().to(device=device)
print(model)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


# Inspect a Module in the Defined Model

Here, we take a single layer of the defined model, **conv1**, as a module and inspect its initial **un-pruned weights**.

We then print the un-pruned weight and bias values of the conv1 layer.

**NOTE:** The parameters of the conv1 layer are named **'weight'** and **'bias'**.

In [6]:
# Take first layer and print out the Un-Pruned parameters i.e.
# weights & biases
module = model.conv1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[-0.2416,  0.1972,  0.0812],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.2302]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.1141, -0.0150, -0.1813],
          [ 0.0240,  0.1568,  0.1943],
          [ 0.1124,  0.0140, -0.0316]]],


        [[[-0.0146,  0.1459, -0.1749],
          [-0.1287,  0.1655, -0.1921],
          [ 0.0227, -0.1735,  0.2135]]],


        [[[-0.1326,  0.0777, -0.1734],
          [ 0.2746, -0.0787,  0.1938],
          [ 0.2148,  0.1972,  0.0917]]],


        [[[ 0.1897,  0.1422, -0.1531],
          [-0.2701,  0.0163,  0.1772],
          [ 0.2436, -0.1219, -0.3129]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1209,  0.3181,  0.0193, -0.2359,  0.2840, -0.2947],
       requires_grad=True))]


In [7]:
# This module has no named buffers for now
print(list(module.named_buffers()))

[]


In [8]:
# Take the first fully connected layer and print out its un-pruned
# parameters, i.e. weights & biases
module_fc = model.fc1
print(list(module_fc.named_parameters()))

[('weight', Parameter containing:
tensor([[-0.0435,  0.0392, -0.0189,  ...,  0.0163, -0.0469,  0.0444],
        [-0.0127,  0.0344,  0.0455,  ..., -0.0046, -0.0064,  0.0457],
        [ 0.0081,  0.0108,  0.0210,  ..., -0.0358,  0.0033, -0.0339],
        ...,
        [ 0.0372, -0.0162, -0.0239,  ...,  0.0357,  0.0484,  0.0002],
        [-0.0243,  0.0014, -0.0417,  ..., -0.0239,  0.0293,  0.0421],
        [-0.0185, -0.0418, -0.0248,  ..., -0.0375, -0.0463,  0.0091]],
       requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0302, -0.0452,  0.0378, -0.0282, -0.0371, -0.0245, -0.0288, -0.0206,
         0.0065, -0.0115, -0.0043, -0.0088, -0.0496, -0.0284,  0.0346,  0.0362,
         0.0363,  0.0235,  0.0384, -0.0388,  0.0079, -0.0074,  0.0469, -0.0407,
         0.0460, -0.0468,  0.0153,  0.0010,  0.0115, -0.0128,  0.0452, -0.0266,
        -0.0278,  0.0468,  0.0331,  0.0175, -0.0459, -0.0147, -0.0239,  0.0475,
         0.0215, -0.0402,  0.0439, -0.0340, -0.0342,  0.0041, -0.0137, -

# Pruning the Module

In this section, we'll prune the first layer, i.e. the conv1 module of the model.

## Steps:

1. Firstly, we'll select a pruning technique among those available in **torch.nn.utils.prune**.

2. Then, **specify the module** i.e. which layer in the model you want to prune and the **name of the parameter to prune**, i.e. **un-pruned layer weight or bias**, within that module. 

3. Finally, using the adequate keyword arguments required by the selected pruning technique, specify the pruning parameters.


In this example, we will **prune at random 30% of the connections** in the **parameter named weight in the conv1 layer**.

The pruning technique function takes the following arguments:

1. **module:** is passed as the first argument to the function.

2. **name:** identifies the parameter within that module using its string identifier, i.e. 'weight' or 'bias'.

3. **amount:** indicates either the fraction of connections to prune (if it is a float between 0. and 1.), or the absolute number of connections to prune (if it is a non-negative integer).
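As a quick check of how **amount** is interpreted, the sketch below prunes two throwaway **nn.Linear** layers, once with a float and once with an integer, and counts the zeros in each resulting mask. The names **layer_frac** and **layer_count** are made up for this illustration and are not part of the notebook's model.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Two throwaway layers with 10 * 10 = 100 weights each (illustrative only)
layer_frac = nn.Linear(10, 10)
layer_count = nn.Linear(10, 10)

# Float amount: fraction of the connections to prune (0.3 of 100 weights)
prune.random_unstructured(layer_frac, name="weight", amount=0.3)
# Integer amount: absolute number of connections to prune
prune.random_unstructured(layer_count, name="weight", amount=7)

# Count the pruned positions, i.e. the zeros in each mask
zeros_frac = int((layer_frac.weight_mask == 0).sum())
zeros_count = int((layer_count.weight_mask == 0).sum())
print(zeros_frac, zeros_count)
```

Which positions get pruned is random, but the number of pruned positions is fixed by **amount**.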

In [9]:
# Prune at random 30% of the weights in Conv1 layer
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

## Renaming the Original Layer Parameters

Pruning acts by removing **weight** from the parameters and replacing it with a new parameter called **weight_orig** (i.e. appending "_orig" to the initial parameter name).

**weight_orig** stores the **un-pruned version of the tensor (weight or bias)**. The bias was not pruned, so it will remain intact.

In [10]:
# See how after applying Pruning, the Original Weights of the Model are
# stored and renamed as 'weight_orig'.
# The 'weight_orig' are un-pruned weights of the layer.
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.1209,  0.3181,  0.0193, -0.2359,  0.2840, -0.2947],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.2416,  0.1972,  0.0812],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.2302]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.1141, -0.0150, -0.1813],
          [ 0.0240,  0.1568,  0.1943],
          [ 0.1124,  0.0140, -0.0316]]],


        [[[-0.0146,  0.1459, -0.1749],
          [-0.1287,  0.1655, -0.1921],
          [ 0.0227, -0.1735,  0.2135]]],


        [[[-0.1326,  0.0777, -0.1734],
          [ 0.2746, -0.0787,  0.1938],
          [ 0.2148,  0.1972,  0.0917]]],


        [[[ 0.1897,  0.1422, -0.1531],
          [-0.2701,  0.0163,  0.1772],
          [ 0.2436, -0.1219, -0.3129]]]], requires_grad=True))]


## Weight Mask

When we call a pruning method on a layer/module, it renames the original weights and creates a mask, with the same shape as the original un-pruned weights, in the module's buffers.

The pruning mask generated by the pruning technique selected above is saved as a module buffer named **weight_mask** (i.e. appending **"_mask"** to the initial parameter name).

This **'weight_mask'** is then **used to zero out the corresponding values of the original weight tensor ('weight_orig')**. What remains are the values in 'weight_orig' that correspond to a "1" in 'weight_mask'. Hence, the weights are pruned.
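The relationship between the three tensors can be checked by hand. The following is a minimal sketch on a throwaway **nn.Conv2d** layer (not the notebook's model) that recomputes the pruned weight as the elementwise product of **weight_orig** and **weight_mask**.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Conv2d(1, 2, 3)  # throwaway layer, illustrative only
prune.random_unstructured(layer, name="weight", amount=0.5)

# Recompute the pruned weight by hand from the two stored tensors
recomputed = layer.weight_orig * layer.weight_mask

# The mask has the same shape as the original weights ...
same_shape = layer.weight_mask.shape == layer.weight_orig.shape
# ... and the 'weight' attribute equals the masked original
matches = torch.equal(layer.weight, recomputed)
print(same_shape, matches)
```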

In [11]:
# The weight_mask tells which weight values to prune and which to leave.
# All weights in "weight_orig" corresponding to a "1" in the mask remain intact;
# all weights corresponding to a "0" in the mask are pruned.
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 1.],
          [1., 1., 0.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]]]))]


**For the forward pass to work without modification, the weight attribute needs to exist.** The pruning techniques implemented in torch.nn.utils.prune **compute the pruned version of the weight (by combining the mask with the original parameter) and store them in the attribute weight.** Note, this is no longer a parameter of the module, it is now simply an attribute.
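One consequence of this re-parametrization, sketched below on a throwaway **nn.Linear** layer, is that gradients flow to **weight_orig** through the mask, so the pruned positions receive zero gradient. This is an illustration under that assumption about the masked product, not part of the original notebook.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(3, 2)  # throwaway layer, illustrative only
prune.random_unstructured(layer, name="weight", amount=0.5)

# 'weight' is no longer listed among the parameters; 'weight_orig' is
param_names = [name for name, _ in layer.named_parameters()]
# The 'weight' attribute is a plain tensor computed from weight_orig and the mask
is_parameter = isinstance(layer.weight, nn.Parameter)

# Backpropagate through one forward pass
layer(torch.ones(1, 3)).sum().backward()
# Gradient entries of weight_orig at the pruned (mask == 0) positions
grad_at_pruned = layer.weight_orig.grad[layer.weight_mask == 0]
print(param_names, is_parameter, grad_at_pruned)
```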

In [12]:
# Pruned Weights for Conv1 Layer
# Weights set to "0" for places where mask has a "0"
print(module.weight)

tensor([[[[-0.0000,  0.1972,  0.0000],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.0000]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.0000, -0.0000, -0.1813],
          [ 0.0240,  0.1568,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0146,  0.1459, -0.1749],
          [-0.0000,  0.0000, -0.1921],
          [ 0.0227, -0.1735,  0.2135]]],


        [[[-0.1326,  0.0777, -0.1734],
          [ 0.2746, -0.0000,  0.1938],
          [ 0.0000,  0.1972,  0.0917]]],


        [[[ 0.1897,  0.0000, -0.1531],
          [-0.0000,  0.0163,  0.0000],
          [ 0.2436, -0.1219, -0.3129]]]], grad_fn=<MulBackward0>)


In [13]:
# This verifies that the original parameters are not changed unless you do so
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.1209,  0.3181,  0.0193, -0.2359,  0.2840, -0.2947],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.2416,  0.1972,  0.0812],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.2302]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.1141, -0.0150, -0.1813],
          [ 0.0240,  0.1568,  0.1943],
          [ 0.1124,  0.0140, -0.0316]]],


        [[[-0.0146,  0.1459, -0.1749],
          [-0.1287,  0.1655, -0.1921],
          [ 0.0227, -0.1735,  0.2135]]],


        [[[-0.1326,  0.0777, -0.1734],
          [ 0.2746, -0.0787,  0.1938],
          [ 0.2148,  0.1972,  0.0917]]],


        [[[ 0.1897,  0.1422, -0.1531],
          [-0.2701,  0.0163,  0.1772],
          [ 0.2436, -0.1219, -0.3129]]]], requires_grad=True))]


Finally, **pruning is applied prior to each forward pass using PyTorch’s forward_pre_hooks**. Specifically, **when the module is pruned**, as we have done here, **it will acquire a forward_pre_hook for each parameter associated with it that gets pruned**. In this case, since we have so far only pruned the original parameter named weight, only one hook will be present.
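To see the hook at work, the sketch below overwrites **weight_orig** in place (as an optimizer step would) and compares the **weight** attribute before and after a forward pass. The layer is a throwaway **nn.Linear**, used only for illustration.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)  # throwaway layer, illustrative only
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Overwrite the original parameter in place, as an optimizer step would
with torch.no_grad():
    layer.weight_orig.fill_(1.0)

# With weight_orig set to all ones, an up-to-date 'weight' would equal the mask.
# Before the next forward pass, 'weight' still holds the old masked values.
equal_before = torch.equal(layer.weight, layer.weight_mask)

# Calling the module triggers the forward_pre_hook, which recomputes 'weight'
layer(torch.zeros(1, 4))
equal_after = torch.equal(layer.weight, layer.weight_mask)
print(equal_before, equal_after)
```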

In [14]:
# Forward Pass pre Hook
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f2ec9937ba8>)])


For completeness, we can now prune the bias too, to see how the parameters, buffers, hooks, and attributes of the module change. Just for the sake of trying out another pruning technique, here **we prune the 3 smallest entries in the bias by L1 norm**, as implemented in the **l1_unstructured pruning** function.
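The selection rule can be reproduced in isolation. The sketch below copies the bias values printed earlier in this notebook into a fresh **nn.Conv2d** layer and applies the same call; the three entries with the smallest absolute value (0.0193, 0.1209 and -0.2359) are the ones expected to be masked.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Conv2d(1, 6, 3)  # fresh layer, illustrative only
# Copy in the bias values shown in the conv1 output above
with torch.no_grad():
    layer.bias.copy_(
        torch.tensor([0.1209, 0.3181, 0.0193, -0.2359, 0.2840, -0.2947])
    )

# Prune the 3 entries with the smallest absolute value
prune.l1_unstructured(layer, name="bias", amount=3)

bias_mask = layer.bias_mask
print(bias_mask)
```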

In [15]:
# Bias Pruning
prune.l1_unstructured(module, name="bias", amount=3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

We now expect the named parameters to include both **weight_orig (from before) and bias_orig**. The buffers will include **weight_mask and bias_mask**. The pruned versions of the two tensors will exist as module attributes, and the module will now have two forward_pre_hooks.

In [16]:
# Module Original Named Parameters
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[-0.2416,  0.1972,  0.0812],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.2302]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.1141, -0.0150, -0.1813],
          [ 0.0240,  0.1568,  0.1943],
          [ 0.1124,  0.0140, -0.0316]]],


        [[[-0.0146,  0.1459, -0.1749],
          [-0.1287,  0.1655, -0.1921],
          [ 0.0227, -0.1735,  0.2135]]],


        [[[-0.1326,  0.0777, -0.1734],
          [ 0.2746, -0.0787,  0.1938],
          [ 0.2148,  0.1972,  0.0917]]],


        [[[ 0.1897,  0.1422, -0.1531],
          [-0.2701,  0.0163,  0.1772],
          [ 0.2436, -0.1219, -0.3129]]]], requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1209,  0.3181,  0.0193, -0.2359,  0.2840, -0.2947],
       requires_grad=True))]


In [17]:
# The weight_mask tells which weight values to prune and which to leave.
# All weights in "weight_orig" corresponding to a "1" in the mask remain intact;
# all weights corresponding to a "0" in the mask are pruned.
# The same holds for the bias values in "bias_orig" and "bias_mask".
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 1.],
          [1., 1., 0.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]]])), ('bias_mask', tensor([0., 1., 0., 0., 1., 1.]))]


In [18]:
# Pruned Bias values for Conv1 Layer
# Bias set to "0" for places where mask has a "0"
print(module.bias)

tensor([ 0.0000,  0.3181,  0.0000, -0.0000,  0.2840, -0.2947],
       grad_fn=<MulBackward0>)


In [19]:
# Forward Pass pre Hook
# There should be 2 hooks, one for weights, other for bias
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f2ec9937ba8>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f2ec994f6a0>)])


# Iterative Pruning

The same parameter in a module can be pruned multiple times, with the effect of the various pruning calls being equal to the combination of the various masks applied in series. 

The combination of a new mask with the old mask is handled by the **PruningContainer**’s **compute_mask method**.

Say, for example, that we now want to further prune **module.weight**, this time **using structured pruning along the 0th axis of the tensor** (the 0th axis corresponds to the output channels of the convolutional layer and has dimensionality 6 for conv1), based on the channels’ L2 norm. 

This can be achieved using the **ln_structured** function, with **n=2 and dim=0**.
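To make the channel selection visible, here is a minimal sketch on a fresh, previously un-pruned **nn.Conv2d** layer whose output channels are filled with constants of increasing magnitude, so their L2 norms are ordered by construction. The layer size and fill values are assumptions chosen for the illustration.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Conv2d(1, 4, 3)  # 4 output channels, illustrative only
# Fill channel i with the constant 0.1 * (i + 1): the L2 norm grows with i
with torch.no_grad():
    for channel in range(4):
        layer.weight[channel].fill_(0.1 * (channel + 1))

# Prune 50% of the output channels (dim=0), ranked by their L2 norm (n=2)
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

# A channel is kept if its mask slice contains any non-zero entry
channel_kept = (layer.weight_mask.flatten(1).sum(dim=1) > 0).tolist()
print(channel_kept)
```

The two channels with the smallest norms (indices 0 and 1) are removed as whole units, which is what distinguishes structured from unstructured pruning.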

In [20]:
# Iteratively prune the model, i.e.
# serially apply multiple masks to the same layer parameter (weight/bias)

# Pruning Parameters
# module: Model Layer
# name: Model layer parameter to prune
# amount: Fraction (float) or absolute number (int) of channels to prune
# n: Order of the norm to use (2 = L2 norm). Ref: https://pytorch.org/docs/master/generated/torch.norm.html#torch.norm
# dim: index of the dim along which to prune channels of the tensor
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [21]:
# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

tensor([[[[-0.0000,  0.1972,  0.0000],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.0000]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 0.1897,  0.0000, -0.1531],
          [-0.0000,  0.0163,  0.0000],
          [ 0.2436, -0.1219, -0.3129]]]], grad_fn=<MulBackward0>)


The corresponding hook will now be of type **torch.nn.utils.prune.PruningContainer**, and will store the history of pruning applied to the weight parameter.

In [22]:
for hook in module._forward_pre_hooks.values():
  # Select the correct hook
  if hook._tensor_name == "weight":
    break

# Pruning History in the Container
print(list(hook))

[<torch.nn.utils.prune.RandomUnstructured object at 0x7f2ec9937ba8>, <torch.nn.utils.prune.LnStructured object at 0x7f2ec994f390>]
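How the container combines masks can be checked by counting zeros. In the sketch below (a throwaway 4x4 **nn.Linear**, not the notebook's model), two unstructured pruning calls of 50% each are applied in series; the second call acts only on the still-unpruned half, so 8 and then 12 of the 16 weights should be masked.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 4)  # 16 weights, illustrative only

# First pass: prune 50% of all 16 weights at random
prune.random_unstructured(layer, name="weight", amount=0.5)
zeros_first = int((layer.weight_mask == 0).sum())

# Second pass: prune 50% of the 8 weights that are still unpruned
prune.l1_unstructured(layer, name="weight", amount=0.5)
zeros_second = int((layer.weight_mask == 0).sum())

# The hook for 'weight' is now a PruningContainer holding both methods
hook = next(
    h for h in layer._forward_pre_hooks.values() if h._tensor_name == "weight"
)
history = [type(method).__name__ for method in hook]
print(zeros_first, zeros_second, history)
```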


# Serializing Pruned Model

All relevant tensors, including the **mask buffers** and the original parameters used to compute the pruned tensors are stored in the model’s **state_dict** and can therefore be easily serialized and saved, if needed.

In [23]:
print(model.state_dict().keys())

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
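Because the keys are named **weight_orig** and **weight_mask**, a freshly built model cannot load this state_dict directly. One possible approach, sketched below under the assumption that **prune.identity** is an acceptable way to set up the same re-parametrization on the target, is to apply an all-ones mask first and then load. The layers **src** and **dst** are throwaway examples.

```python
import io

import torch
from torch import nn
import torch.nn.utils.prune as prune

# Source layer, pruned as usual
src = nn.Linear(4, 2)
prune.l1_unstructured(src, name="weight", amount=0.5)

# Serialize the state_dict to an in-memory buffer (a file path works the same way)
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

# Target layer: create 'weight_orig' and 'weight_mask' with an all-ones mask
# so that its state_dict keys match those of the pruned source
dst = nn.Linear(4, 2)
prune.identity(dst, name="weight")
dst.load_state_dict(torch.load(buffer))

# The forward pass recomputes 'weight' from the loaded tensors
x = torch.randn(3, 4)
outputs_match = torch.allclose(src(x), dst(x))
masks_match = torch.equal(src.weight_mask, dst.weight_mask)
print(outputs_match, masks_match)
```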


# Remove pruning re-parametrization

To make the pruning permanent, i.e. to remove the re-parametrization in terms of weight_orig and weight_mask along with the forward_pre_hook, we can use the **remove** function from torch.nn.utils.prune.

Note that this doesn’t undo the pruning, as if it never happened. Instead, it makes the pruning permanent by reassigning the pruned version of weight as a regular parameter of the module.

## Prior to removing the re-parametrization

In [24]:
# Layer Parameters Prior to removing the re-parametrization
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[-0.2416,  0.1972,  0.0812],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.2302]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.1141, -0.0150, -0.1813],
          [ 0.0240,  0.1568,  0.1943],
          [ 0.1124,  0.0140, -0.0316]]],


        [[[-0.0146,  0.1459, -0.1749],
          [-0.1287,  0.1655, -0.1921],
          [ 0.0227, -0.1735,  0.2135]]],


        [[[-0.1326,  0.0777, -0.1734],
          [ 0.2746, -0.0787,  0.1938],
          [ 0.2148,  0.1972,  0.0917]]],


        [[[ 0.1897,  0.1422, -0.1531],
          [-0.2701,  0.0163,  0.1772],
          [ 0.2436, -0.1219, -0.3129]]]], requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1209,  0.3181,  0.0193, -0.2359,  0.2840, -0.2947],
       requires_grad=True))]


In [25]:
# Pruning Mask Prior to removing the re-parametrization
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]]])), ('bias_mask', tensor([0., 1., 0., 0., 1., 1.]))]


In [26]:
# Pruned Layer Weights before removing the re-parametrization
print(module.weight)

tensor([[[[-0.0000,  0.1972,  0.0000],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.0000]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 0.1897,  0.0000, -0.1531],
          [-0.0000,  0.0163,  0.0000],
          [ 0.2436, -0.1219, -0.3129]]]], grad_fn=<MulBackward0>)


## After removing the re-parametrization

In [27]:
# prune.remove makes the pruning permanent:
# 'weight_orig' (un-pruned layer parameters) and 'weight_mask' are deleted,
# and 'weight' becomes a regular parameter again,
# holding the pruned layer parameters
prune.remove(module, 'weight')

# Note that 'weight_orig' is gone and 'weight' now holds the pruned values
print(list(module.named_parameters()))

[('bias_orig', Parameter containing:
tensor([ 0.1209,  0.3181,  0.0193, -0.2359,  0.2840, -0.2947],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0000,  0.1972,  0.0000],
          [ 0.2535,  0.1136, -0.3117],
          [ 0.2692, -0.0886, -0.0000]]],


        [[[-0.2093, -0.2896,  0.0456],
          [-0.0730, -0.0114,  0.2277],
          [ 0.1133,  0.3246, -0.0525]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 0.1897,  0.0000, -0.1531],
          [-0.0000,  0.0163,  0.0000],
          [ 0.2436, -0.1219, -0.3129]]]], requires_grad=True))]


In [28]:
# After removing the re-parametrization, 'weight_mask' is gone; only 'bias_mask' remains
print(list(module.named_buffers()))

[('bias_mask', tensor([0., 1., 0., 0., 1., 1.]))]
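The effect of **remove** can also be verified programmatically. The following compact sketch uses a throwaway **nn.Linear** layer and checks that the pruned values survive while the extra parameter, buffer and hook disappear.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)  # throwaway layer, illustrative only
prune.l1_unstructured(layer, name="weight", amount=0.5)
pruned_before = layer.weight.detach().clone()

# Make the pruning permanent
prune.remove(layer, "weight")

param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = [name for name, _ in layer.named_buffers()]
hooks_left = len(layer._forward_pre_hooks)
# The zeros introduced by pruning are kept in the plain 'weight' parameter
values_kept = torch.equal(layer.weight.detach(), pruned_before)
print(param_names, buffer_names, hooks_left, values_kept)
```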


# Pruning Multiple Parameters in a Model

In [29]:
# Instantiate the Model
new_model = LeNet()

for name, module in new_model.named_modules():
  # Prune the Conv2D layer weights by 20%
  if isinstance(module, torch.nn.Conv2d):
    prune.l1_unstructured(module=module, name='weight', amount=0.2)
  
  # Prune the Fully Connected layer weights by 40%
  if isinstance(module, torch.nn.Linear):
    prune.l1_unstructured(module=module, name='weight', amount=0.4)

# Verify that all masks exist
print(dict(new_model.named_buffers()).keys())

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])


In [30]:
# Check that 'weight' has been renamed to 'weight_orig' after pruning
module1 = new_model.conv1
print(list(module1.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.2191, -0.2709, -0.1657, -0.1700, -0.1311,  0.1843],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1706,  0.2018,  0.0024],
          [-0.1397,  0.2210, -0.2712],
          [ 0.0319, -0.1529,  0.1179]]],


        [[[-0.2569, -0.1396, -0.3133],
          [-0.3023, -0.1992,  0.2522],
          [-0.1212, -0.2803, -0.0088]]],


        [[[-0.1832, -0.1292,  0.2746],
          [-0.2691, -0.2559,  0.2355],
          [-0.0645, -0.0838, -0.2616]]],


        [[[ 0.3041,  0.2713, -0.1414],
          [-0.2913, -0.0827, -0.1309],
          [-0.2568,  0.0478, -0.0426]]],


        [[[ 0.1748,  0.3034, -0.1001],
          [-0.0638, -0.0835,  0.1669],
          [ 0.0770,  0.1859, -0.1979]]],


        [[[-0.2213, -0.1200, -0.2784],
          [ 0.1652,  0.3087,  0.2424],
          [-0.0592, -0.1538,  0.2409]]]], requires_grad=True))]


In [31]:
# Pruned Weights for Module1
print(module1.weight)

tensor([[[[-0.1706,  0.2018,  0.0000],
          [-0.1397,  0.2210, -0.2712],
          [ 0.0000, -0.1529,  0.1179]]],


        [[[-0.2569, -0.1396, -0.3133],
          [-0.3023, -0.1992,  0.2522],
          [-0.1212, -0.2803, -0.0000]]],


        [[[-0.1832, -0.1292,  0.2746],
          [-0.2691, -0.2559,  0.2355],
          [-0.0000, -0.0838, -0.2616]]],


        [[[ 0.3041,  0.2713, -0.1414],
          [-0.2913, -0.0000, -0.1309],
          [-0.2568,  0.0000, -0.0000]]],


        [[[ 0.1748,  0.3034, -0.1001],
          [-0.0000, -0.0000,  0.1669],
          [ 0.0000,  0.1859, -0.1979]]],


        [[[-0.2213, -0.1200, -0.2784],
          [ 0.1652,  0.3087,  0.2424],
          [-0.0000, -0.1538,  0.2409]]]], grad_fn=<MulBackward0>)


In [32]:
# Check the same for Fully Connected Layer
module5 = new_model.fc3
print(list(module5.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.0073,  0.0726, -0.0494, -0.0478,  0.0205,  0.0878, -0.0457,  0.0783,
        -0.0423, -0.0953], requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[-0.0121,  0.0096,  0.0407,  0.0687, -0.0731, -0.0717,  0.0151, -0.0051,
          0.0559,  0.0077, -0.0048,  0.1006, -0.0579, -0.0172, -0.0556, -0.0644,
          0.0444, -0.0422, -0.0488, -0.0573,  0.0707, -0.0927,  0.0407,  0.0150,
          0.1050,  0.0166, -0.0963, -0.0201,  0.0359, -0.0087, -0.0280,  0.0911,
         -0.0363,  0.0627, -0.0203, -0.0232, -0.0317, -0.1035,  0.0670,  0.0689,
          0.0956,  0.0087,  0.0932,  0.0392, -0.0263, -0.0063,  0.0969,  0.0067,
         -0.0283, -0.0926, -0.1000,  0.0536,  0.1029,  0.0457, -0.0223, -0.0838,
         -0.0846, -0.0883, -0.0229, -0.0800, -0.0601, -0.0663,  0.0160, -0.0242,
         -0.0375,  0.0216, -0.0981, -0.0271,  0.0756, -0.0685, -0.0093, -0.0656,
          0.0446, -0.0808,  0.0861, -0.0299,  0.0736, -0.0383, -0.1009, 

In [33]:
# Pruned Weights for Module5 i.e. last FC layer
print(module5.weight)

tensor([[-0.0000,  0.0000,  0.0407,  0.0687, -0.0731, -0.0717,  0.0000, -0.0000,
          0.0559,  0.0000, -0.0000,  0.1006, -0.0579, -0.0000, -0.0556, -0.0644,
          0.0444, -0.0422, -0.0488, -0.0573,  0.0707, -0.0927,  0.0407,  0.0000,
          0.1050,  0.0000, -0.0963, -0.0000,  0.0000, -0.0000, -0.0000,  0.0911,
         -0.0000,  0.0627, -0.0000, -0.0000, -0.0000, -0.1035,  0.0670,  0.0689,
          0.0956,  0.0000,  0.0932,  0.0000, -0.0000, -0.0000,  0.0969,  0.0000,
         -0.0000, -0.0926, -0.1000,  0.0536,  0.1029,  0.0457, -0.0000, -0.0838,
         -0.0846, -0.0883, -0.0000, -0.0800, -0.0601, -0.0663,  0.0000, -0.0000,
         -0.0000,  0.0000, -0.0981, -0.0000,  0.0756, -0.0685, -0.0000, -0.0656,
          0.0446, -0.0808,  0.0861, -0.0000,  0.0736, -0.0000, -0.1009,  0.0430,
          0.0000, -0.0000, -0.0419, -0.0808],
        [ 0.1074,  0.0989, -0.0910, -0.1057, -0.0538,  0.0886, -0.0662,  0.0798,
          0.0652,  0.0880, -0.0934, -0.0956,  0.0000, -0.0769, 

# Global Pruning

In Global Pruning, we **prune the model all at once**, by removing (for example) the lowest 20% of connections across the whole model, instead of removing the lowest 20% of connections in each layer. 

This is likely to result in different pruning percentages per layer. We apply it using **global_unstructured** from **torch.nn.utils.prune**.

In [34]:
# Instantiate the Model
model = LeNet()

# Define the Layers and their Parameters (weight/bias) to Prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

# Prune the whole model globally, i.e.
# rank the weights of all listed layers together and remove the smallest ones,
# instead of defining a pruning percentage per layer.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    # 20% pruning across all layers combined
    amount=0.2,
)

In [35]:
model

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [36]:
# Check all Model layers Original Values
# All layers show 'weight_orig' as the model has been pruned for weight
# For name and params in all model named_parameters
for name, param in model.named_parameters():
  # If the parameter is updatable i.e. requires_grad = True
  if param.requires_grad:
    # Print the Layer Name and its Parameters
    print(name, param.data)

conv1.bias tensor([-0.1796, -0.1487, -0.1289,  0.0827,  0.0039, -0.1429])
conv1.weight_orig tensor([[[[ 0.2952, -0.0966, -0.1436],
          [-0.3235, -0.1454, -0.0882],
          [-0.1662,  0.0028,  0.0075]]],


        [[[ 0.0438,  0.2900,  0.2357],
          [ 0.1284, -0.1668,  0.2187],
          [ 0.1519,  0.0483,  0.0010]]],


        [[[ 0.0661, -0.1535,  0.2512],
          [-0.0804, -0.1958, -0.1135],
          [ 0.1595,  0.3300, -0.0040]]],


        [[[ 0.2241, -0.2220,  0.2834],
          [ 0.1820,  0.3208,  0.0883],
          [ 0.0850, -0.3110,  0.1855]]],


        [[[ 0.0314,  0.1655, -0.1768],
          [-0.3055,  0.1989,  0.0736],
          [ 0.2201,  0.2403, -0.1500]]],


        [[[-0.1379, -0.2519, -0.0242],
          [ 0.3203,  0.1620, -0.1227],
          [-0.1037, -0.2465,  0.2425]]]])
conv2.bias tensor([-0.1147,  0.0860, -0.1281, -0.0708,  0.1334, -0.0463,  0.0458,  0.0561,
         0.0008,  0.1178, -0.0454,  0.0139, -0.0689,  0.0034,  0.0701,  0.0936])
conv2.weigh

Now we can check the sparsity induced in every pruned parameter, which **will not be equal to 20% in each layer**. However, the **global sparsity will be (approximately) 20%**.

In [37]:
# Print the Sparsity per Layer for the Pruned Model
# The total Sparsity of the Model will be around 20%

# Sparsity % Calculation per Layer = 
# 100 * (Number of zero weights after pruning in the layer) / (Total Number of Weights in the layer)

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)

print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)

print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)

print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)

print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)

# Print Global Sparsity of the Model across all layers

# Sparsity % Calculation for the Model = 
# 100 * sum(Number of zero weights after pruning in each layer) / sum(Total Number of Weights in each layer)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

Sparsity in conv1.weight: 7.41%
Sparsity in conv2.weight: 9.03%
Sparsity in fc1.weight: 21.98%
Sparsity in fc2.weight: 12.46%
Sparsity in fc3.weight: 9.76%
Global sparsity: 20.00%
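The repeated print statements above can be folded into a small helper. The sketch below is one way to do it; the **sparsity** function and the two-layer stand-in model are hypothetical and only serve to keep the example self-contained.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Small stand-in model (illustrative only, not the LeNet defined above)
model = nn.Sequential(nn.Linear(20, 30), nn.ReLU(), nn.Linear(30, 5))
layers = [m for m in model if isinstance(m, nn.Linear)]

# Prune 20% of all weights, ranked together across both layers
prune.global_unstructured(
    [(m, "weight") for m in layers],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

def sparsity(tensors):
    """Percentage of zero entries over a list of tensors (hypothetical helper)."""
    zeros = sum(int((t == 0).sum()) for t in tensors)
    total = sum(t.nelement() for t in tensors)
    return 100.0 * zeros / total

# Per-layer sparsity usually differs; the global value is close to 20%
per_layer = [sparsity([m.weight]) for m in layers]
global_sparsity = sparsity([m.weight for m in layers])
for index, value in enumerate(per_layer):
    print("Sparsity in layer {}: {:.2f}%".format(index, value))
print("Global sparsity: {:.2f}%".format(global_sparsity))
```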


# Extending "torch.nn.utils.prune" Functionality with Custom Pruning Functions

## Steps for creating Custom Pruning Functions

1. To implement your own pruning function, you can extend the **nn.utils.prune module** by subclassing the **BasePruningMethod** base class, the same way all other pruning methods do. 

2. The base class implements the following methods for you: 
  
  **\_\_call\_\_**

  **apply_mask**

  **apply**

  **prune**

  **remove**

  Beyond some special cases, you shouldn’t have to reimplement these methods for your new pruning technique. 
  
3. You will, however, have to implement **`compute_mask`** (the instructions for computing the mask for a given tensor according to the logic of your pruning technique) and, if your technique takes arguments, **`__init__`** (the constructor).

4. In addition, you will have to specify which type of pruning this technique implements (supported options are **global, structured, and unstructured**). This is needed to determine **how to combine masks when pruning is applied iteratively**. In other words, when pruning a pre-pruned parameter, the current pruning technique is expected to act on the unpruned portion of the parameter. Specifying the **PRUNING_TYPE** enables the **PruningContainer** (which handles the iterative application of pruning masks) to correctly identify the slice of the parameter to prune.
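
The iterative behaviour described in step 4 can be observed with the built-in unstructured methods. The sketch below is an illustration under assumed settings (a small `nn.Linear(4, 3)` layer named `layer` and `amount=0.5` for both passes, neither taken from the notebook): the second pass only acts on the entries that the first pass left unpruned, and the two methods end up inside a single `PruningContainer`.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # weight has 3 * 4 = 12 entries

# First pass: randomly prune half of the 12 weights
prune.random_unstructured(layer, name="weight", amount=0.5)
zeros_after_first = int((layer.weight_mask == 0).sum())

# Second pass: prune half of the 6 weights that are still unpruned
prune.l1_unstructured(layer, name="weight", amount=0.5)
zeros_after_second = int((layer.weight_mask == 0).sum())

print(zeros_after_first, zeros_after_second)  # 6 9

# The single forward pre-hook is now a container holding both methods
hook = next(iter(layer._forward_pre_hooks.values()))
method_names = [type(m).__name__ for m in hook]
print(type(hook).__name__, method_names)
```

Because both methods are of type 'unstructured', the container restricts the second method to the entries where the existing mask is still 1, which is why the second pass removes 3 weights rather than 6.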

Let’s assume, for example, that you want to implement a pruning technique that **prunes every other entry in a tensor** (or – if the tensor has previously been pruned – in the remaining unpruned portion of the tensor). This will be of **PRUNING_TYPE='unstructured'** because it acts on individual connections in a layer and not on entire **units/channels ('structured')**, or **across different parameters ('global')**.

In [38]:
# Firstly, create the Custom Pruning Class with the Method/Function
# This custom Class Subclasses the "BasePruningMethod".
class customPruningMethod(prune.BasePruningMethod):
  # Define the PRUNING TYPE i.e. structured, unstructured or global
  PRUNING_TYPE = 'unstructured'

  # Define the function for Custom Pruning
  # Here we are defining a function that prunes every alternate value in the tensor
  # t: tensor to prune, default_mask: existing mask for `t` (all ones if `t` has not been pruned before)
  def compute_mask(self, t, default_mask):
    # Create a copy of 'default_mask', which has the same size/shape as the tensor
    mask = default_mask.clone()
    # Set every alternate value in 'mask' to 0
    # view(-1) flattens the mask; [::2] selects every second entry, starting at index 0
    mask.view(-1)[::2] = 0
    return mask
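
To make the slicing in `compute_mask` concrete, here is a small worked example on a 2x3 mask of ones. The tensor shape is an arbitrary choice for illustration; only standard tensor operations are used.

```python
import torch

# A 2x3 mask of ones, like the default mask of a parameter that was never pruned
mask = torch.ones(2, 3)

# view(-1) shares storage with `mask`, so writing through it edits `mask` in place.
# Flat indices 0, 2 and 4 are set to zero.
mask.view(-1)[::2] = 0
print(mask)
# tensor([[0., 1., 0.],
#         [1., 0., 1.]])
```

Since the stride runs over the flattened tensor, a row with an odd number of columns shifts the pattern in the next row, while a 2-D weight with an even number of columns has the same columns zeroed in every row.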

Now, to apply this to a parameter in an nn.Module, you should also provide **a simple function that instantiates the method and applies it**.

In [39]:
# Create a Function to call the Custom Pruning function from the Class
def custom_unstructured(module, name):
  """
  Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
  """
  customPruningMethod.apply(module, name)
  return module

## Try out the Custom Pruning Function

In [40]:
# Instantiate the Model
model = LeNet()
model

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [41]:
# Original Model FC3 Layer Weights
model.fc3.weight.data

tensor([[ 5.1730e-02, -1.0493e-01,  3.7107e-02, -8.1432e-02,  3.6493e-03,
         -1.6143e-02, -5.9364e-02, -7.9780e-02,  4.9360e-03,  2.9754e-02,
         -1.0489e-01,  4.8641e-02, -2.4668e-02, -2.8385e-02, -7.2158e-02,
         -7.9618e-02,  6.7679e-02,  7.3266e-02,  1.2629e-03,  4.1988e-02,
          8.9891e-02, -9.7443e-02, -7.8273e-02,  3.7919e-02, -2.4429e-02,
          5.4867e-02,  9.6588e-02,  1.0614e-01, -9.7069e-02, -9.4999e-02,
          9.5408e-02,  5.3908e-02, -8.8701e-02,  1.7467e-04,  1.9297e-02,
          1.0780e-01, -2.4014e-02,  2.8175e-02,  3.4580e-02,  7.2127e-03,
         -9.9485e-02, -6.3003e-03,  7.2915e-02, -1.3120e-02,  8.7878e-02,
          3.6286e-02,  9.8314e-03, -1.0587e-01, -4.7081e-02, -5.5500e-03,
          3.6411e-02,  5.2733e-02, -3.4934e-02, -7.0024e-02, -1.0029e-01,
         -2.9627e-04,  5.3550e-02, -2.2353e-02, -8.7671e-02,  2.2571e-02,
         -8.5857e-02, -6.5263e-02,  1.6634e-02,  6.1099e-03,  2.5267e-02,
          9.3630e-02,  7.3828e-02, -8.

In [42]:
# Shape of Original FC3 Layer Weight data
model.fc3.weight.data.shape

torch.Size([10, 84])

In [43]:
# Original Model FC3 Layer Bias Data
model.fc3.bias.data

tensor([-0.0654, -0.0262, -0.0466, -0.0926, -0.0189,  0.0428,  0.0982, -0.0424,
         0.0266,  0.0498])

In [44]:
# Shape of Original FC3 Layer Bias data
model.fc3.bias.data.shape

torch.Size([10])

In [45]:
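# Preview the effect of the mask logic on a copy of the bias
# clone() keeps the model's own bias values untouched before pruning is applied
u = model.fc3.bias.data.clone()
u.view(-1)[::2] = 0
u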

tensor([ 0.0000, -0.0262,  0.0000, -0.0926,  0.0000,  0.0428,  0.0000, -0.0424,
         0.0000,  0.0498])

In [46]:
# Apply Custom Pruning Function
custom_unstructured(module=model.fc3, name='bias')

Linear(in_features=84, out_features=10, bias=True)

In [47]:
# Module Named Parameters after Pruning
# After pruning, the parameter 'bias' is replaced by 'bias_orig',
# which stores the values from before the mask was applied
print(list(model.fc3.named_parameters()))

[('weight', Parameter containing:
tensor([[ 5.1730e-02, -1.0493e-01,  3.7107e-02, -8.1432e-02,  3.6493e-03,
         -1.6143e-02, -5.9364e-02, -7.9780e-02,  4.9360e-03,  2.9754e-02,
         -1.0489e-01,  4.8641e-02, -2.4668e-02, -2.8385e-02, -7.2158e-02,
         -7.9618e-02,  6.7679e-02,  7.3266e-02,  1.2629e-03,  4.1988e-02,
          8.9891e-02, -9.7443e-02, -7.8273e-02,  3.7919e-02, -2.4429e-02,
          5.4867e-02,  9.6588e-02,  1.0614e-01, -9.7069e-02, -9.4999e-02,
          9.5408e-02,  5.3908e-02, -8.8701e-02,  1.7467e-04,  1.9297e-02,
          1.0780e-01, -2.4014e-02,  2.8175e-02,  3.4580e-02,  7.2127e-03,
         -9.9485e-02, -6.3003e-03,  7.2915e-02, -1.3120e-02,  8.7878e-02,
          3.6286e-02,  9.8314e-03, -1.0587e-01, -4.7081e-02, -5.5500e-03,
          3.6411e-02,  5.2733e-02, -3.4934e-02, -7.0024e-02, -1.0029e-01,
         -2.9627e-04,  5.3550e-02, -2.2353e-02, -8.7671e-02,  2.2571e-02,
         -8.5857e-02, -6.5263e-02,  1.6634e-02,  6.1099e-03,  2.5267e-02,
    

In [48]:
# Custom Mask Applied to the Bias values
print(list(model.fc3.named_buffers()))

[('bias_mask', tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.]))]


In [49]:
# Access the 'bias_mask' from the Model
print(model.fc3.bias_mask)

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
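
The relationship between `bias`, `bias_orig` and `bias_mask` can be checked directly. The sketch below restates the every-other-entry method under the hypothetical name `EveryOtherPruning` and applies it to a fresh `nn.Linear(84, 10)` layer (an assumed stand-in for `model.fc3`) so that it runs on its own.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


class EveryOtherPruning(prune.BasePruningMethod):
    """Re-statement of the every-other-entry method under a hypothetical name."""
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask


layer = nn.Linear(84, 10)
EveryOtherPruning.apply(layer, "bias")

# 'bias' is no longer a Parameter; 'bias_orig' takes its place
names = [n for n, _ in layer.named_parameters()]
print(names)  # ['weight', 'bias_orig']

# The pruned 'bias' attribute equals the stored original times the mask
matches = torch.equal(layer.bias, layer.bias_orig * layer.bias_mask)
print(matches)  # True

# is_pruned reports whether any pruning hook is attached to the module
pruned = prune.is_pruned(layer)
print(pruned)  # True
```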


In [50]:
# Pruning registers a forward pre-hook that re-applies the mask before each forward pass
model.fc3._forward_pre_hooks

OrderedDict([(13, <__main__.customPruningMethod at 0x7f2ec98922b0>)])

In [51]:
# `module` refers to a layer pruned earlier in the notebook, before `model`
# was re-instantiated, so it is not the `model.fc3` pruned above.
# Its weights still show the zeros left by that earlier pruning.
print(module.weight)

tensor([[-0.0000,  0.0000,  0.0407,  0.0687, -0.0731, -0.0717,  0.0000, -0.0000,
          0.0559,  0.0000, -0.0000,  0.1006, -0.0579, -0.0000, -0.0556, -0.0644,
          0.0444, -0.0422, -0.0488, -0.0573,  0.0707, -0.0927,  0.0407,  0.0000,
          0.1050,  0.0000, -0.0963, -0.0000,  0.0000, -0.0000, -0.0000,  0.0911,
         -0.0000,  0.0627, -0.0000, -0.0000, -0.0000, -0.1035,  0.0670,  0.0689,
          0.0956,  0.0000,  0.0932,  0.0000, -0.0000, -0.0000,  0.0969,  0.0000,
         -0.0000, -0.0926, -0.1000,  0.0536,  0.1029,  0.0457, -0.0000, -0.0838,
         -0.0846, -0.0883, -0.0000, -0.0800, -0.0601, -0.0663,  0.0000, -0.0000,
         -0.0000,  0.0000, -0.0981, -0.0000,  0.0756, -0.0685, -0.0000, -0.0656,
          0.0446, -0.0808,  0.0861, -0.0000,  0.0736, -0.0000, -0.1009,  0.0430,
          0.0000, -0.0000, -0.0419, -0.0808],
        [ 0.1074,  0.0989, -0.0910, -0.1057, -0.0538,  0.0886, -0.0662,  0.0798,
          0.0652,  0.0880, -0.0934, -0.0956,  0.0000, -0.0769, 

In [52]:
print(module._forward_pre_hooks)

OrderedDict([(7, <torch.nn.utils.prune.L1Unstructured object at 0x7f2ec98eb6a0>)])
