[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DoranLyong/Awesome-Tensor-Architecture/blob/main/pytorch_reference/simple_reference/06_PyTorch_Acceleration_and_Optimization/05_Pruning.ipynb)



# Pruning (p.183)
```Pruning``` is a technique that ```reduces``` the number of ```model parameters``` with ```minimal effect``` on performance. 

This allows you to deploy models with:
* less ```memory```, 
* lower ```power usage```, 
* and reduced ```hardware resources```. 

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Pruning model example 
__Pruning__ can be applied to an ```nn.Module```. 

Since an ```nn.Module``` may consist of a ```single layer```, ```multiple layers```, or an ```entire model```, <br/>
```pruning``` can be applied to a single layer, multiple layers, or the entire model itself.

In [2]:
# Basic model 

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Expects input of shape (N, 3, 32, 32), so the feature map is 16 x 5 x 5 before fc1.
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Our LeNet5 model has five submodules - ```conv1```, ```conv2```, ```fc1```, ```fc2```, and ```fc3```. 

In [11]:
model = LeNet5().to(device)


# Let's look at the parameters of the conv1 layer.

for name, param in model.conv1.named_parameters():
    print(name)
    print(param.size())



weight
torch.Size([6, 3, 5, 5])
bias
torch.Size([6])


*** 
## Local and global pruning (p.184)

### (ex1) Local pruning: 
```Local pruning``` prunes only a ```specific piece``` of the model. <br/>
With this technique we can apply ```local pruning``` to a ```single layer``` or ```module```.

In [12]:
import torch.nn.utils.prune as prune

prune.random_unstructured(  model.conv1,     # target layer 
                            name = "weight", # parameter name
                            amount=0.25,     # fraction of entries to prune (chosen at random)
                        )     

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))

In [13]:
prune.random_unstructured(  model.conv1,     # target layer 
                            name = "bias",   # parameter name
                            amount=0.25,     
                        )  

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
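
To see what the pruning calls above changed, it can help to inspect the module afterwards. The sketch below builds a standalone `nn.Conv2d(3, 6, 5)` as a stand-in for `model.conv1` (so the block runs by itself) and relies on the reparameterization that `torch.nn.utils.prune` applies: the original tensor is kept as the parameter `weight_orig`, a binary buffer `weight_mask` is registered, and `weight` becomes their product. The helper name `mask_sparsity` is made up for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Stand-in for model.conv1 so that this block runs on its own.
conv1 = nn.Conv2d(3, 6, 5)
prune.random_unstructured(conv1, name="weight", amount=0.25)

# After pruning, the original tensor lives in `weight_orig` (a parameter)
# and the binary mask lives in `weight_mask` (a buffer).
param_names = [name for name, _ in conv1.named_parameters()]
buffer_names = [name for name, _ in conv1.named_buffers()]
print(param_names)    # contains 'weight_orig' instead of 'weight'
print(buffer_names)   # contains 'weight_mask'


def mask_sparsity(module, name="weight"):
    """Fraction of entries that the pruning mask sets to zero (hypothetical helper)."""
    mask = getattr(module, name + "_mask")
    return float((mask == 0).sum()) / mask.nelement()


print(f"conv1.weight sparsity: {mask_sparsity(conv1):.2%}")

# `weight` is recomputed as weight_orig * weight_mask.
assert torch.equal(conv1.weight, conv1.weight_orig * conv1.weight_mask)
```

The printed sparsity should be close to the requested 25%; it is not exactly 25% here because the number of pruned entries is rounded to a whole number.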

### Prune modules and parameters differently (p.185)

For example: 
* prune by module or layer type, 
* apply ```pruning``` to ```conv``` layers ```differently``` than ```linear``` layers.

In [14]:
model = LeNet5().to(device)

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d): 
        prune.random_unstructured(module, name='weight', amount=0.3) # Prune all 2D conv layers by 30%

    elif isinstance(module, torch.nn.Linear):
        prune.random_unstructured(module,  name='weight', amount=0.5) # Prune all linear layers by 50%
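
A quick way to check the per-type rule is to read the masks back after pruning. The following is a minimal sketch, not the notebook's own code: it uses a small `nn.Sequential` stand-in instead of `LeNet5` so that it is self-contained, and the `sparsity` dictionary is an illustrative name.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Small stand-in network (not the LeNet5 above) so the block is self-contained.
net = nn.Sequential(
    nn.Conv2d(3, 6, 5),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(6 * 28 * 28, 10),
)

# Same rule as above: 30% for conv layers, 50% for linear layers.
for module in net.modules():
    if isinstance(module, nn.Conv2d):
        prune.random_unstructured(module, name="weight", amount=0.3)
    elif isinstance(module, nn.Linear):
        prune.random_unstructured(module, name="weight", amount=0.5)

# Collect the fraction of masked weights for every module that owns a weight mask.
sparsity = {}
for name, module in net.named_modules():
    if hasattr(module, "weight_mask"):
        mask = module.weight_mask
        sparsity[name] = float((mask == 0).sum()) / mask.nelement()

for name, value in sparsity.items():
    print(f"{name}: {value:.1%} of weights pruned")
```

In an `nn.Sequential`, submodules are named by their index, so the conv layer is reported as `0` and the linear layer as `3`.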

### (ex2) Global pruning 
```Global pruning``` applies a pruning method to the ```entire model``` at once. 

For example: 
* prune ```25%``` of our ```model's parameters``` ```globally```, which will usually result in a different pruning rate for each layer.

In [15]:
model = LeNet5().to(device)

parameters_to_prune = ( (model.conv1, 'weight'),
                        (model.conv2, 'weight'),
                        (model.fc1, 'weight'),
                        (model.fc2, 'weight'),
                        (model.fc3, 'weight'),
                        )

prune.global_unstructured(  parameters_to_prune,
                            pruning_method=prune.L1Unstructured,
                            amount=0.25
                        ) 

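# Here, we prune 25% of the weights across the five listed tensors taken together.
# L1Unstructured removes the entries with the smallest absolute value.
# Biases are not in the list, so they are left untouched.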
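
The claim that a global budget leads to different per-layer rates can be checked directly. The sketch below is an illustration under assumptions: it uses two stand-in `nn.Linear` layers of very different size rather than `LeNet5`, and `pruned_count` is a hypothetical helper. The overall fraction matches the requested 25%, while the per-layer fractions depend on how the weight magnitudes are distributed in each layer.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Two stand-in layers of very different size (not the LeNet5 above).
small = nn.Linear(10, 10)    # 100 weights
large = nn.Linear(100, 100)  # 10,000 weights

prune.global_unstructured(
    [(small, "weight"), (large, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.25,
)


def pruned_count(module):
    """Number of weights masked out in `module` (hypothetical helper)."""
    return int((module.weight_mask == 0).sum())


total = small.weight.nelement() + large.weight.nelement()
overall = (pruned_count(small) + pruned_count(large)) / total

print(f"overall: {overall:.1%}")
print(f"small:   {pruned_count(small) / small.weight.nelement():.1%}")
print(f"large:   {pruned_count(large) / large.weight.nelement():.1%}")
```

With PyTorch's default initialization the larger layer tends to hold smaller weights, so it usually absorbs most of the pruning in this setup.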

***
## Custom pruning methods (p.187)
Create your own pruning method: 
* subclass the ```BasePruningMethod``` class in ```torch.nn.utils.prune```,
* implement ```compute_mask``` so that it returns the mask to apply.

In [16]:
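class MyPruningMethod(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'  # the mask acts on individual entries

    def compute_mask(self, t, default_mask):
        # Start from the existing mask and zero out every other entry.
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask


def my_unstructured(module, name):
    # Apply the custom method to parameter `name` of `module` (in place).
    MyPruningMethod.apply(module, name)
    return module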

In [17]:
model = LeNet5().to(device)

my_unstructured(model.fc1, name='bias')

Linear(in_features=400, out_features=120, bias=True)
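
To confirm that the custom method does what it is meant to, the resulting mask can be read back from the module. This is a minimal, self-contained sketch: `EveryOtherPruning` repeats the logic of `MyPruningMethod` under a different (illustrative) name, and `fc` is a stand-in for `model.fc1`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class EveryOtherPruning(prune.BasePruningMethod):
    """Same idea as MyPruningMethod above: mask every other entry."""
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask


fc = nn.Linear(400, 120)            # stand-in for model.fc1
EveryOtherPruning.apply(fc, "bias")

print(fc.bias_mask[:6])             # tensor([0., 1., 0., 1., 0., 1.])
kept = int(fc.bias_mask.sum())
print(f"{kept} of {fc.bias_mask.nelement()} bias entries kept")

# Even positions of the effective bias are now zero.
assert torch.all(fc.bias[::2] == 0)
```

Half of the 120 bias entries are masked, and the untouched values remain available in `fc.bias_orig`.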