# Network Pruning in PyTorch

PyTorch supports post-training network pruning through the torch.nn.utils.prune module. Given a trained network, pruning can be applied by passing a collection of (module, parameter name) pairs to the global_unstructured function. Once a parameter is pruned, a binary mask is attached to its module; the mask records which entries were pruned (0) and which were kept (1). A forward pre-hook multiplies the target parameter by this mask before every forward pass, so pruned weights contribute nothing to the output. Note that the masked weight tensor is still dense, so masking alone does not reduce the amount of computation.

### Create a sample model

In [1]:
import torch
import torch.quantization
import torch.nn as nn

torch.manual_seed(0)  # set the seed for reproducibility

class SampleLinearModel(nn.Module): 

    def __init__(self): 
        super(SampleLinearModel, self).__init__() 
        self.linear1 = nn.Linear(10, 10) 

    def forward(self, x): 
        x = self.linear1(x) 
        return x 

### Original model for comparison

In [2]:
original_model = SampleLinearModel()
print(original_model)

SampleLinearModel(
  (linear1): Linear(in_features=10, out_features=10, bias=True)
)


In [3]:
for param_name, param in original_model.named_parameters():
    print(param_name)

linear1.weight
linear1.bias


### Create pruned model

In this example, we prune the 50% of the weights with the smallest absolute values (L1 magnitude).
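
To make the selection rule concrete, here is a minimal sketch of magnitude-based masking done by hand. The random tensor `weight` and all variable names are illustrative assumptions, not part of the notebook, and the library's own implementation may differ in details.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(10, 10)  # illustrative stand-in for a layer's weight matrix

amount = 0.5
num_to_prune = round(amount * weight.numel())  # 50 of the 100 entries

# Indices (in the flattened tensor) of the entries with the smallest magnitude
_, prune_idx = torch.topk(weight.abs().flatten(), k=num_to_prune, largest=False)

# Binary mask: 0 marks a pruned entry, 1 marks a kept entry
mask = torch.ones(weight.numel())
mask[prune_idx] = 0
mask = mask.view_as(weight)

pruned_weight = weight * mask  # pruned entries become exactly zero
print(int(mask.sum()))         # number of weights kept
```

Every pruned entry has a magnitude no larger than any kept entry, which is the sense in which the "lowest" weights are removed.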

In [4]:
pruned_model = SampleLinearModel()
print(pruned_model)

SampleLinearModel(
  (linear1): Linear(in_features=10, out_features=10, bias=True)
)


In [5]:
for param_name, param in pruned_model.named_parameters():
    print(param_name)

linear1.weight
linear1.bias


In [6]:
import torch.nn.utils.prune as prune

parameters_to_prune = ( 
    (pruned_model.linear1, 'weight'),
) 

prune.global_unstructured(
    parameters_to_prune, 
    pruning_method=prune.L1Unstructured,
    amount=0.5
)

The pruned module now has a mask (weight_mask) and a forward pre-hook that applies the mask to the weights before each forward pass.

In [7]:
pruned_model.linear1.weight_mask

tensor([[0., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
        [0., 0., 1., 0., 1., 0., 0., 0., 0., 1.],
        [1., 1., 1., 0., 1., 0., 1., 0., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 0., 1., 1., 0., 0., 1., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 0., 1., 0.],
        [0., 1., 1., 1., 1., 1., 0., 1., 0., 0.]])
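
The relationship between the mask and the weights can be checked directly. The sketch below uses a standalone `nn.Linear` layer as a hypothetical stand-in for `pruned_model.linear1`, so that it runs on its own; the variable names are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(10, 10)  # hypothetical stand-in for pruned_model.linear1

prune.global_unstructured(
    [(layer, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)

# After pruning, the original values live in the parameter `weight_orig`,
# and the mask is stored as the buffer `weight_mask`.
param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
print(param_names)
print(buffer_names)

# Fraction of entries that the mask zeroes out
sparsity = float((layer.weight_mask == 0).sum()) / layer.weight_mask.numel()
print(sparsity)

# `weight` is now a plain attribute computed as weight_orig * weight_mask
matches = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
print(matches)
```

Because `weight` is no longer a parameter, printing `named_parameters()` on a pruned model lists `weight_orig` in its place.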

In [8]:
pruned_model.linear1._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.CustomFromMask at 0x7f5e5877a320>)])
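
The hook and mask stay attached until the pruning is made permanent. As a minimal sketch, assuming a standalone `nn.Linear` layer in place of `pruned_model.linear1`, `prune.remove` can be used to fold the mask into the weight and drop the hook:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(10, 10)  # hypothetical stand-in for pruned_model.linear1

prune.global_unstructured(
    [(layer, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)
hooks_before = len(layer._forward_pre_hooks)

# Make the pruning permanent: `weight` becomes a regular parameter again,
# holding the masked values; `weight_orig` and `weight_mask` are deleted.
prune.remove(layer, "weight")
hooks_after = len(layer._forward_pre_hooks)

param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = [name for name, _ in layer.named_buffers()]
zero_fraction = float((layer.weight == 0).sum()) / layer.weight.numel()

print(hooks_before, hooks_after)
print(param_names, buffer_names)
print(zero_fraction)
```

The zeros remain in the weight tensor after removal; only the bookkeeping (mask buffer, original parameter, and hook) goes away.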