# PyTorch Pruning Tutorial

This tutorial is written specifically for the LotteryFL implementation. The methods and techniques covered below are used throughout this repo. The imports below are all of the dependencies you need for this tutorial.

In [1]:
import copy
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from util import prune_fixed_amount, get_prune_params, train, create_model
from tabulate import tabulate

## 1. PyTorch Model Primer

This section covers the structure of a PyTorch Module/Model. While TensorFlow has `tf.keras.Model` and `tf.keras.layers.Layer`, PyTorch only has `nn.Module`. This means that both the model itself and its various layers are sub-classes of `nn.Module`. Therefore, both the model (e.g. VGG) and the layers (e.g. Conv2d) have `.named_parameters()` and `.named_buffers()` methods.

When we say a PyTorch 'model', we are referring to an instance of `nn.Module` that may or may not have sub-modules. Each of the sub-modules can in turn contain sub-sub-modules. To demonstrate, below is a code snippet that recursively accesses all of the sub-modules in a PyTorch model.

In [2]:
model = nn.Sequential(nn.Linear(100, 10), nn.Linear(10, 1))
for module in model.modules():
    print(module, end=',\n\n')

Sequential(
  (0): Linear(in_features=100, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=1, bias=True)
),

Linear(in_features=100, out_features=10, bias=True),

Linear(in_features=10, out_features=1, bias=True),



As we can see above, the modules returned by `model.modules()` include the model itself as well as the sub-modules. Below is a more complicated example where a `nn.Sequential` module is nested in another `nn.Sequential` module.

In [3]:
model = nn.Sequential(nn.Linear(100, 10), 
                      nn.Linear(10, 10),
                      nn.Sequential(
                          nn.Linear(10, 3),
                          nn.Linear(3, 1)
                      )
                     )
for module in model.modules():
    print(module, end=',\n\n')

Sequential(
  (0): Linear(in_features=100, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): Sequential(
    (0): Linear(in_features=10, out_features=3, bias=True)
    (1): Linear(in_features=3, out_features=1, bias=True)
  )
),

Linear(in_features=100, out_features=10, bias=True),

Linear(in_features=10, out_features=10, bias=True),

Sequential(
  (0): Linear(in_features=10, out_features=3, bias=True)
  (1): Linear(in_features=3, out_features=1, bias=True)
),

Linear(in_features=10, out_features=3, bias=True),

Linear(in_features=3, out_features=1, bias=True),
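
The output above mixes container modules with the layers that actually hold parameters. As a minimal sketch (not part of `util.py`), one way to tell them apart is to treat a module with no children as a leaf. The variable name `leaf_names` is only for illustration.

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(100, 10),
    nn.Linear(10, 10),
    nn.Sequential(nn.Linear(10, 3), nn.Linear(3, 1)),
)

# A module with no children is a leaf; containers such as nn.Sequential
# (including the root model here) always have at least one child.
leaf_names = [
    name
    for name, module in model.named_modules()
    if len(list(module.children())) == 0
]

print(leaf_names)  # ['0', '1', '2.0', '2.1']
```

`named_modules()` yields the same modules as `modules()`, paired with their dotted names, which makes it easier to see where each layer sits in the hierarchy.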



To build the list of sub-modules to prune, we need to check whether each module returned by `model.modules()` is 1. the model itself or 2. an `nn.Sequential` container. A module that matches either case is skipped, because its parameters are already reachable through its children and including it would count them twice. Below is the `get_prune_params()` method from `util.py`. The outer for-loop checks whether a particular module satisfies either 1. or 2.

In [4]:
def get_prune_params(model):
    layers = []
    
    num_global_weights = 0
    
    modules = list(model.modules())
    
    for layer in modules:
        
        is_sequential = type(layer) == nn.Sequential
        
        is_itself = type(layer) == type(model) if len(modules) > 1 else False
        
        if (not is_sequential) and (not is_itself):
            for name, param in layer.named_parameters():
                
                field_name = name.split('.')[-1]
                
                # This might break if someone does not adhere to the naming
                # convention where weights of a module is stored in a field
                # that has the word 'weight' in it
                
                if 'weight' in field_name and param.requires_grad:
                    
                    if field_name.endswith('_orig'):
                        field_name = field_name[:-5]
                    
                    # Might remove the param.requires_grad condition in the future
                    
                    layers.append((layer, field_name))
                
                    num_global_weights += torch.numel(param)
                    
    return layers, num_global_weights

The inner for-loop goes through all of the parameters of each sub-module. If the name of a parameter contains the word 'weight', we treat it as a weight rather than a bias. This might seem like a 'hack' for telling weights and biases apart, but we have yet to figure out a better way to do it. In addition to checking whether a parameter is a weight, we also check that it takes part in the gradient calculation by checking that `param.requires_grad` is True. Finally, `get_prune_params()` returns two items: a list of `(nn.Module, str)` tuples, where the first entry is a reference to the module to be pruned and the second entry is the name of the parameter to be pruned, and the total number of weights across those parameters. In the second section, we will see how to prune a model using the output of `get_prune_params()`.
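
One possible alternative to matching on the string 'weight' is to select layers by their type. The sketch below is an assumption and not what `util.py` does: the function name `get_prune_params_by_type` and the `PRUNABLE_TYPES` tuple are hypothetical, and the `requires_grad` check is left out for brevity.

```python
import torch.nn as nn

# Hypothetical whitelist of layer types whose `weight` we want to prune.
PRUNABLE_TYPES = (nn.Linear, nn.Conv2d)

def get_prune_params_by_type(model):
    """Collect (module, 'weight') pairs for every layer of a whitelisted type."""
    layers = []
    num_global_weights = 0
    for module in model.modules():
        if isinstance(module, PRUNABLE_TYPES):
            layers.append((module, "weight"))
            # `module.weight` exists both before and after pruning.
            num_global_weights += module.weight.numel()
    return layers, num_global_weights

model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.Flatten(), nn.Linear(4, 2))
layers, total = get_prune_params_by_type(model)
print([(type(m).__name__, name) for m, name in layers])
print(total)  # 4*1*3*3 + 2*4 = 44
```

The trade-off is that every new layer type has to be added to the whitelist by hand, whereas the name-based check picks up any parameter that follows the 'weight' naming convention.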

You might be wondering why we include the following code in the `get_prune_params()` method. We will get to that in the next section.

``` python
if field_name.endswith('_orig'):
    field_name = field_name[:-5]
```

**Donglin's Note:** As I am writing this, I realized that it is easier to just go through the model (aka the 'root' module), call the `.named_parameters()`, and get all of the weights without the need of the inner for-loop. However, this does not allow us to see how much each 'layer' is getting pruned. I guess we can call this a feature.
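
The simpler approach mentioned in the note could look roughly like the sketch below. The helper name `list_weight_names` is hypothetical. It only collects fully qualified parameter names from the root module, so the owning layer would still have to be looked up before calling any pruning function.

```python
import torch.nn as nn

def list_weight_names(model):
    """Return the fully qualified names of all trainable weights in the model."""
    names = []
    for name, param in model.named_parameters():
        # Same naming convention as get_prune_params(): the last path
        # component must contain the word 'weight'.
        field_name = name.split(".")[-1]
        if "weight" in field_name and param.requires_grad:
            names.append(name)
    return names

model = nn.Sequential(
    nn.Linear(100, 10),
    nn.Sequential(nn.Linear(10, 3), nn.Linear(3, 1)),
)
print(list_weight_names(model))  # ['0.weight', '1.0.weight', '1.1.weight']
```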

## 2. PyTorch Pruning Primer
Let's start by defining a dummy MLP with a single `nn.Linear(10, 2)` layer, i.e. 10 inputs and 2 outputs (20 weights in total, plus 2 biases). Below, we can see the difference between `model.modules()` and the `get_prune_params(model)` method we implemented above. `get_prune_params(model)` returns only the modules that we need, whereas `model.modules()` also returns 'wrapper' modules that don't need to be pruned.

In [5]:
class DummyMLP(nn.Module):
    def __init__(self, num_classes=10):
        super(DummyMLP, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(10, 2)
        )
        
    def forward(self, x):
        return self.classifier(x)
model = DummyMLP()
print('_________________________________________________________________________________________')
print("Returns from model.modules()")
print('=========================================================================================')
for layer in model.modules():
    print(layer)
    

layers_to_prune, _ = get_prune_params(model)
print('_________________________________________________________________________________________')
print("Returns from get_prune_params(model)")
print('=========================================================================================')
for layer in layers_to_prune:
    print(layer)

_________________________________________________________________________________________
Returns from model.modules()
DummyMLP(
  (classifier): Sequential(
    (0): Linear(in_features=10, out_features=2, bias=True)
  )
)
Sequential(
  (0): Linear(in_features=10, out_features=2, bias=True)
)
Linear(in_features=10, out_features=2, bias=True)
_________________________________________________________________________________________
Returns from get_prune_params(model)
(Linear(in_features=10, out_features=2, bias=True), 'weight')


We now have a way to obtain all of the layers to prune. However, before we do any pruning, let's first look at what gets returned by `model.named_parameters()` and `model.named_buffers()`.

In [6]:
print('_________________________________________________________________________________________')
print("Returns from model.named_parameters")
print('=========================================================================================')
print(list(model.named_parameters()))

print('_________________________________________________________________________________________')
print("Returns from model.named_buffers")
print('=========================================================================================')
print(list(model.named_buffers()))


_________________________________________________________________________________________
Returns from model.named_parameters
[('classifier.0.weight', Parameter containing:
tensor([[ 2.8054e-01,  2.2104e-04,  1.6230e-02,  1.1865e-01,  3.1461e-02,
          1.3736e-01,  1.3134e-01, -2.8906e-01,  2.0455e-01, -4.2221e-02],
        [ 2.4345e-01,  2.0637e-01, -2.5642e-01, -8.5109e-02, -6.6805e-02,
         -1.5327e-01, -2.7775e-01, -1.9317e-01, -2.6005e-01,  2.7742e-01]],
       requires_grad=True)), ('classifier.0.bias', Parameter containing:
tensor([-0.2621,  0.0467], requires_grad=True))]
_________________________________________________________________________________________
Returns from model.named_buffers
[]


The following is the `.weight` attribute of a particular layer/module.

In [7]:
print(layers_to_prune[0][0].weight)

Parameter containing:
tensor([[ 2.8054e-01,  2.2104e-04,  1.6230e-02,  1.1865e-01,  3.1461e-02,
          1.3736e-01,  1.3134e-01, -2.8906e-01,  2.0455e-01, -4.2221e-02],
        [ 2.4345e-01,  2.0637e-01, -2.5642e-01, -8.5109e-02, -6.6805e-02,
         -1.5327e-01, -2.7775e-01, -1.9317e-01, -2.6005e-01,  2.7742e-01]],
       requires_grad=True)


Now we will do the pruning and see how it affects the inner structure of the model. We take the `layers_to_prune` obtained from `get_prune_params()` and pass it to the `torch.nn.utils.prune.global_unstructured()` method, which does the pruning in-place. Because `amount` is an integer here, it is the absolute number of weights to prune; a float between 0 and 1 would instead be interpreted as a fraction of the weights. We do not have to pass the entire model to `global_unstructured()`, only `layers_to_prune`. This is because `layers_to_prune` contains references to the model's layers, so the pruning can be done in-place through those references without access to the entire model.

In [8]:
torch.nn.utils.prune.global_unstructured(layers_to_prune,
                                         pruning_method=prune.L1Unstructured,
                                         amount=10)
print(list(model.named_parameters()))
print(list(model.named_buffers()))

[('classifier.0.bias', Parameter containing:
tensor([-0.2621,  0.0467], requires_grad=True)), ('classifier.0.weight_orig', Parameter containing:
tensor([[ 2.8054e-01,  2.2104e-04,  1.6230e-02,  1.1865e-01,  3.1461e-02,
          1.3736e-01,  1.3134e-01, -2.8906e-01,  2.0455e-01, -4.2221e-02],
        [ 2.4345e-01,  2.0637e-01, -2.5642e-01, -8.5109e-02, -6.6805e-02,
         -1.5327e-01, -2.7775e-01, -1.9317e-01, -2.6005e-01,  2.7742e-01]],
       requires_grad=True))]
[('classifier.0.weight_mask', tensor([[1., 0., 0., 0., 0., 0., 0., 1., 1., 0.],
        [1., 1., 1., 0., 0., 0., 1., 1., 1., 1.]]))]


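As we can see from above, the pruning module takes the original weights and adds an '_orig' suffix to their name. It also adds a buffer named 'weight_mask' that holds the binary pruning mask. The bias terms are left unchanged. If we access the `.weight` attribute of the first and only layer/module of the DummyMLP, we can see that the `requires_grad=True` entry is replaced with `grad_fn=<MulBackward0>`. This is because `.weight` is no longer a `Parameter`: the pruning module recomputes it internally as `weight_orig * weight_mask`, so gradients flow back to `weight_orig` through that multiplication.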

In [9]:
print(layers_to_prune[0][0].weight)

tensor([[ 0.2805,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.2891,
          0.2045, -0.0000],
        [ 0.2435,  0.2064, -0.2564, -0.0000, -0.0000, -0.0000, -0.2777, -0.1932,
         -0.2601,  0.2774]], grad_fn=<MulBackward0>)
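
To make the relationship between `weight`, `weight_orig`, and `weight_mask` concrete, here is a small self-contained check. It uses `prune.l1_unstructured()` on a single fresh layer instead of `global_unstructured()` on the DummyMLP, purely to keep the sketch short.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(10, 2)
prune.l1_unstructured(layer, name="weight", amount=10)  # prune 10 of 20 weights

param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
print(param_names)   # 'weight' is gone, 'weight_orig' has taken its place
print(buffer_names)  # the mask is stored as a buffer

# `weight` is a plain tensor derived from the original weights and the mask.
reconstructed = layer.weight_orig * layer.weight_mask
print(torch.equal(layer.weight, reconstructed))
print(int((layer.weight_mask == 0).sum()))  # number of pruned entries
```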


We have now covered all of the methods needed to construct the `prune_fixed_amount()` method in `util.py`. However, before we end the section, there is an important caveat we need to touch on. You might recall that we mentioned the following snippet from the `get_prune_params()` method.

``` python
if field_name.endswith('_orig'):
    field_name = field_name[:-5]
```

This is to check if a model has been pruned before. Below, we have defined a version of `get_prune_params()` without this code snippet. Given a DummyMLP with 20 weights, we first prune 5 weights and then 5 weights again.

In [10]:
def get_prune_params(model):
    layers = []
    
    num_global_weights = 0
    
    modules = list(model.modules())
    
    for layer in modules:
        
        is_sequential = type(layer) == nn.Sequential
        
        is_itself = type(layer) == type(model) if len(modules) > 1 else False
        
        if (not is_sequential) and (not is_itself):
            for name, param in layer.named_parameters():
                
                field_name = name.split('.')[-1]
                
                # This might break if someone does not adhere to the naming
                # convention where weights of a module is stored in a field
                # that has the word 'weight' in it
                
                if 'weight' in field_name and param.requires_grad:
                    
                    # Might remove the param.requires_grad condition in the future
                    
                    layers.append((layer, field_name))
                
                    num_global_weights += torch.numel(param)
                    
    return layers, num_global_weights

def prune_fixed_amount(model, amount, verbose=True):
    parameters_to_prune, num_global_weights = get_prune_params(model)
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount)

    num_global_zeros, num_layer_zeros, num_layer_weights = 0, 0, 0
    global_prune_percent, layer_prune_percent = 0, 0
    prune_stat = {'Layers': [],
                  'Weight Name': [],
                  'Percent Pruned': [],
                  'Total Pruned': []}
    
    # Pruning is done in-place, thus parameters_to_prune is updated
    for layer, weight_name in parameters_to_prune:
        
        num_layer_zeros = torch.sum(getattr(layer, weight_name) == 0.0).item()
        num_global_zeros += num_layer_zeros
        num_layer_weights = torch.numel(getattr(layer, weight_name))
        layer_prune_percent = num_layer_zeros / num_layer_weights * 100
        prune_stat['Layers'].append(layer.__str__())
        prune_stat['Weight Name'].append(weight_name)
        prune_stat['Percent Pruned'].append(f'{num_layer_zeros} / {num_layer_weights} ({layer_prune_percent:.5f}%)')
        prune_stat['Total Pruned'].append(f'{num_layer_zeros}')
        
    global_prune_percent = num_global_zeros / num_global_weights
    if verbose:
        print('Pruning Summary', flush=True)
        print(tabulate(prune_stat, headers='keys'), flush=True)
        print(f'Percent Pruned Globaly: {global_prune_percent:.2f}', flush=True)

model = DummyMLP()

print('_________________________________________________________________________________________')
print("Model after 1st round of pruning")
print('=========================================================================================')

prune_fixed_amount(model, 5)
print()
print(list(model.named_parameters()))

print('_________________________________________________________________________________________')
print("Model after 2nd round of pruning")
print('=========================================================================================')
prune_fixed_amount(model, 5)
print()
print(list(model.named_parameters()))

_________________________________________________________________________________________
Model after 1st round of pruning
Pruning Summary
Layers                                             Weight Name    Percent Pruned        Total Pruned
-------------------------------------------------  -------------  ------------------  --------------
Linear(in_features=10, out_features=2, bias=True)  weight         5 / 20 (25.00000%)               5
Percent Pruned Globaly: 0.25

[('classifier.0.bias', Parameter containing:
tensor([-0.2877, -0.3014], requires_grad=True)), ('classifier.0.weight_orig', Parameter containing:
tensor([[ 0.2245,  0.2599,  0.2035, -0.3123,  0.0654, -0.1012,  0.1646,  0.2186,
         -0.2231, -0.1974],
        [ 0.0487, -0.3001,  0.0796, -0.1016,  0.2889, -0.2548,  0.0478,  0.0560,
          0.2037,  0.2126]], requires_grad=True))]
_________________________________________________________________________________________
Model after 2nd round of pruning
Pruning Summary
Lay

As we can see here, if we don't check whether a weight has already been pruned, `get_prune_params()` will treat 'weight_orig' as the actual name of the weight instead of just 'weight'. We can also see that, even though we have pruned 10 weights in total across the two rounds, the second pruning summary still shows that only 5 weights have been pruned. This is because, if we use 'weight_orig' as the name of the parameter to be pruned during the second round, PyTorch's pruning module treats it as a brand new set of weights with a completely different name. It does not make the association between 'weight_orig' in the second round and 'weight' in the first round.
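
For contrast, the sketch below prunes the same layer twice while always passing the base name 'weight', which is what stripping the '_orig' suffix achieves. It calls `global_unstructured()` directly on a single fresh layer rather than going through `prune_fixed_amount()`.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(10, 2)           # 20 weights
params = [(layer, "weight")]       # always the base name, never 'weight_orig'

zeros_per_round = []
for _ in range(2):
    prune.global_unstructured(
        params,
        pruning_method=prune.L1Unstructured,
        amount=5,                  # integer: prune 5 more weights each round
    )
    zeros_per_round.append(int((layer.weight_mask == 0).sum()))

print(zeros_per_round)  # [5, 10]: the second round builds on the first mask
print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
```

Because the name stays the same, the second call combines its mask with the existing one, and no 'weight_orig_orig' parameter appears.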

## 3. Changing & Copying weights

Other than the `get_prune_params` and `prune_fixed_amount` methods, `util.py` also has a number of methods that take care of copying models. As we have shown above, PyTorch's pruning methods store the pruned weights in the `.weight` attribute of each layer and add an '_orig' suffix to the original weight names. However, as illustrated below, changing the `.weight` field has no effect on the final model output, because `.weight` is recomputed from `weight_orig` and `weight_mask` before every forward pass. Instead, we need to iterate through the named parameters of each layer and change them individually. The same technique applies when copying the named buffers (masks) of a model.

In [11]:
model = DummyMLP()

layers, _ = get_prune_params(model)
param_shape = layers[0][0].weight.shape
prune_fixed_amount(model, 0, verbose=False)

print('Changing the "weight" directly will not work')
layers[0][0].weight = np.zeros(param_shape)
output = model(torch.zeros(10,)) 
print(output)

print('Changing the named parameters WILL work')
for name, params in layers[0][0].named_parameters():
    params.data = torch.zeros(params.shape)

output = model(torch.zeros(10,)) 
print(output)


Changing the "weight" directly will not work
tensor([-0.1959, -0.2687], grad_fn=<AddBackward0>)
Changing the named parameters WILL work
tensor([0., 0.], grad_fn=<AddBackward0>)
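
The same idea can be used to copy a pruned model into another one. The helper below is a minimal sketch, not the implementation in `util.py`: the name `copy_pruned_state` is hypothetical, and it assumes that both models already have identical parameter and buffer names, which is why the target is first pruned with `amount=0`. A non-zero input is used so that the weights actually contribute to the output.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def copy_pruned_state(source, target):
    """Copy named parameters and named buffers (masks) from source into target."""
    source_params = dict(source.named_parameters())
    source_buffers = dict(source.named_buffers())
    with torch.no_grad():
        for name, param in target.named_parameters():
            param.copy_(source_params[name])
        for name, buf in target.named_buffers():
            buf.copy_(source_buffers[name])

source = nn.Sequential(nn.Linear(10, 2))
target = nn.Sequential(nn.Linear(10, 2))

prune.l1_unstructured(source[0], name="weight", amount=5)
# Pruning by 0 leaves all weights intact but creates weight_orig / weight_mask,
# so the target ends up with the same names as the source.
prune.l1_unstructured(target[0], name="weight", amount=0)

copy_pruned_state(source, target)

x = torch.ones(10)
print(torch.allclose(source(x), target(x)))
print(torch.equal(source[0].weight_mask, target[0].weight_mask))
```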
