![logo](../../picture/license_header_logo.png)
**Copyright (c) 2020-2021 CertifAI Sdn. Bhd.**

This program is part of OSRFramework. You can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see http://www.gnu.org/licenses/.

Authored by: [Jacklyn Lim](mailto:jacklyn.lim@certifai.ai)

### Import Libraries

In [1]:
import copy
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from utils import download_model, download_dataset, load_model_state_dict, load_dataset, load_image, inspect_module, compare_performance

### Define Model Architecture

In [2]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Note: the input size of this layer depends on the input image size
        self.fc1 = nn.Linear(18496, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
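
The `18496` passed to `fc1` can be checked by tracing the spatial size through the two convolution and pooling stages. The sketch below assumes a 148x148 RGB input; the notebook does not state the image size, and 148 is inferred only because it yields 16 * 34 * 34 = 18496. The helper names (`conv_out`, `pool_out`, `flattened_features`) are hypothetical and are not part of `utils`.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output length of a convolution along one spatial dimension
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    # Output length of max pooling along one spatial dimension
    return (size - kernel) // stride + 1


def flattened_features(height, width, out_channels=16):
    # Two stages of (5x5 convolution, 2x2 max pooling), as in Net.forward
    for _ in range(2):
        height = pool_out(conv_out(height, 5))
        width = pool_out(conv_out(width, 5))
    return out_channels * height * width


print(flattened_features(148, 148))  # 18496
```

If your images have a different size, the same calculation gives the value to use as the first argument of `nn.Linear` in `fc1`.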

### Download Model and Dataset

In [3]:
# model download
MODEL_DOWNLOAD_PATH = 'https://s3.eu-central-1.wasabisys.com/certifai/deployment-training-labs/models/fruit_classifier_state_dict.pt'
MODEL_STATE_DICT_PATH = '../../resources/model/'
MODEL_FILENAME = 'fruits_image_classification.zip'
download_model(MODEL_DOWNLOAD_PATH, MODEL_STATE_DICT_PATH, MODEL_FILENAME)

# data download
DATA_DOWNLOAD_PATH = "https://s3.eu-central-1.wasabisys.com/certifai/deployment-training-labs/fruits_image_classification-20210604T123547Z-001.zip"
DATA_SAVE_PATH = "../../resources/data/"
DATA_ZIP_FILENAME = "fruits_image_classification.zip"
download_dataset(DATA_DOWNLOAD_PATH, DATA_SAVE_PATH, DATA_ZIP_FILENAME)

Downloading from https://s3.eu-central-1.wasabisys.com/certifai/deployment-training-labs/models/fruit_classifier_state_dict.pt to ../../resources/model/
100% [..........................................................................] 8935167 / 8935167
Done!
data already exists, skipping download


### Load Model

In [4]:
model = Net()
model = load_model_state_dict(model, MODEL_STATE_DICT_PATH + MODEL_FILENAME)

original_model = copy.deepcopy(model)
single_pruned_model = copy.deepcopy(model)
globally_pruned_model = copy.deepcopy(model)

### Iteratively Prune Model

The only difference between a singly pruned model and an iteratively pruned model is that, in the iterative case, the same parameter is pruned multiple times instead of just once.

In [5]:
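def iterative_prune(model, modules, num_iterations, global_pruning=False):
    # `modules` is a single module for local pruning, or an iterable of
    # (module, parameter_name) pairs for global pruning.
    # `model` is not used; pruning is applied in place to `modules`.
    if global_pruning:
        for _ in range(num_iterations):
            prune.global_unstructured(
                modules,
                pruning_method=prune.L1Unstructured,
                amount=0.3,
            )
    else:
        for _ in range(num_iterations):
            # prune the module passed in, not a global variable
            prune.l1_unstructured(modules, name="weight", amount=0.3)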

In [6]:
# local pruning of a single layer (conv1), applied iteratively
module = single_pruned_model.conv1
iterative_prune(single_pruned_model, module, num_iterations=5)

# global pruning
modules_to_prune = (
    (globally_pruned_model.conv1, 'weight'),
    (globally_pruned_model.conv2, 'weight'),
    (globally_pruned_model.fc1, 'weight'),
    (globally_pruned_model.fc2, 'weight'),
    (globally_pruned_model.fc3, 'weight'),
)
iterative_prune(globally_pruned_model, modules_to_prune, num_iterations=5, global_pruning=True)

The successive pruning calls made during iterative pruning amount to a sequence of masks applied one after another. Combining each new mask with the existing one is handled by the ```PruningContainer```'s ```compute_mask``` method.

The corresponding hook is now of type ```torch.nn.utils.prune.PruningContainer``` and stores the history of pruning applied to the ```weight``` parameter.

In [7]:
for hook in module._forward_pre_hooks.values():
    print(hook)
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

<torch.nn.utils.prune.PruningContainer object at 0x0000014D8B8B5610>
[<torch.nn.utils.prune.L1Unstructured object at 0x0000014DBF2F2F70>, <torch.nn.utils.prune.L1Unstructured object at 0x0000014DBE2DE670>, <torch.nn.utils.prune.L1Unstructured object at 0x0000014D8B8B56D0>, <torch.nn.utils.prune.L1Unstructured object at 0x0000014DBE2D2850>, <torch.nn.utils.prune.L1Unstructured object at 0x0000014DBE2D28E0>]
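
To see how successive masks combine, here is a minimal sketch on a toy layer rather than on the notebook's model. The `nn.Linear(4, 2)` layer and the 50% amount are illustrative assumptions. Each call prunes a fraction of the weights that are still unmasked, so two 50% passes over 8 weights keep 4 and then 2.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 2)  # toy layer with 8 weights (illustrative only)

# First pass: mask the 50% of weights with the smallest magnitude
prune.l1_unstructured(layer, name="weight", amount=0.5)
kept_after_first = int(layer.weight_mask.sum())

# Second pass: mask 50% of the weights that survived the first pass
prune.l1_unstructured(layer, name="weight", amount=0.5)
kept_after_second = int(layer.weight_mask.sum())

# After the second call, the single hook is a container holding both methods
hook = next(iter(layer._forward_pre_hooks.values()))
print(type(hook).__name__, len(hook))
print(kept_after_first, kept_after_second)
```

The second call does not restart from the dense weights: the container restricts the new method to the entries that the earlier mask left intact.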


### Remove Pruning Re-parametrization 

To make the pruning permanent (i.e., keep the parameters zeroed out), we need to remove the re-parametrization in terms of ``weight_orig`` and ``weight_mask``, along with the ``forward_pre_hook``. The ``remove`` function from ``torch.nn.utils.prune`` does this.

It reassigns the attribute ``weight`` as a regular model parameter that holds the pruned values.

In [8]:
# remove the re-parametrization for the locally pruned model
prune.remove(module, 'weight')

# remove the re-parametrization for the globally pruned model
for pruned_module, parameter in modules_to_prune:
    prune.remove(pruned_module, parameter)
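
The effect of `prune.remove` can be inspected by listing a module's parameters and buffers before and after the call. This is a sketch on a toy `nn.Linear(4, 2)` layer (an illustrative assumption), not on the notebook's model.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)  # toy layer (illustrative only)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# While pruned: the parameter is `weight_orig` and the mask is a buffer
params_before = {name for name, _ in layer.named_parameters()}
buffers_before = {name for name, _ in layer.named_buffers()}

prune.remove(layer, "weight")

# After removal: `weight` is a plain parameter again, with zeros baked in
params_after = {name for name, _ in layer.named_parameters()}
buffers_after = {name for name, _ in layer.named_buffers()}
hooks_after = len(layer._forward_pre_hooks)
zeros_after = int((layer.weight == 0).sum())

print(params_before, buffers_before)
print(params_after, buffers_after, hooks_after, zeros_after)
```

After removal the pruning can no longer be undone from the module itself, since `weight_orig` is gone.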

### Compare Model Performance

In [9]:
INFERENCE_IMAGE_PATH = "../../resources/data/fruits_image_classification/test/apple/image1.jpg"
TEST_DATASET_ROOTDIR = "../../resources/data/fruits_image_classification/test"

# load image
inference_image = load_image(INFERENCE_IMAGE_PATH)

# load test dataset
test_dataloader = load_dataset(TEST_DATASET_ROOTDIR)

print("Original model vs single pruned model:")
compare_performance(original_model, single_pruned_model, "original_model", "locally_pruned_model", inference_image, test_dataloader)

print("\n\nOriginal model vs globally pruned model:")
compare_performance(original_model, globally_pruned_model, "original_model", "globally_pruned_model", inference_image, test_dataloader)

Original model vs single pruned model:
Comparing size of models
model:  original_model  	 Size (KB): 8935.103
model:  locally_pruned_model  	 Size (KB): 8935.103
1.00 times smaller

Comparing latency of models
model:  original_model  	 prediction time: 0.007002353668212891s
model:  locally_pruned_model  	 prediction time: 0.0019996166229248047s

Comparing accuracy of models
model:  original_model  	 Test Accuracy: 0.74
model:  locally_pruned_model  	 Test Accuracy: 0.76


Original model vs globally pruned model:
Comparing size of models
model:  original_model  	 Size (KB): 8935.103
model:  globally_pruned_model  	 Size (KB): 8935.103
1.00 times smaller

Comparing latency of models
model:  original_model  	 prediction time: 0.002000570297241211s
model:  globally_pruned_model  	 prediction time: 0.002000570297241211s

Comparing accuracy of models
model:  original_model  	 Test Accuracy: 0.74
model:  globally_pruned_model  	 Test Accuracy: 0.73
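
The identical sizes reported above are what one would expect if `compare_performance` measures the stored dense weights. Its implementation is not shown here, so that is an assumption. Unstructured pruning sets entries to zero but keeps the tensor shape, and a dense tensor stores zeros like any other value. The sketch below illustrates this on a toy layer; `dense_bytes` is a hypothetical helper.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def dense_bytes(tensor):
    # Memory taken by the tensor's elements in dense layout
    return tensor.nelement() * tensor.element_size()


layer = nn.Linear(100, 10)  # toy layer with 1000 float32 weights
bytes_before = dense_bytes(layer.weight)

prune.l1_unstructured(layer, name="weight", amount=0.8)
prune.remove(layer, "weight")

bytes_after = dense_bytes(layer.weight)
nonzero_after = int(torch.count_nonzero(layer.weight))

print(bytes_before, bytes_after, nonzero_after)
```

Shrinking the stored model would take an extra step not covered in this notebook, such as a sparse storage format or compression of the saved file.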


### Compare Sparsity

In [10]:
def calculate_global_sparsity(model):
    return 100. * float(
        torch.sum(model.conv1.weight == 0)
        + torch.sum(model.conv2.weight == 0)
        + torch.sum(model.fc1.weight == 0)
        + torch.sum(model.fc2.weight == 0)
        + torch.sum(model.fc3.weight == 0)
    ) / float(
        model.conv1.weight.nelement()
        + model.conv2.weight.nelement()
        + model.fc1.weight.nelement()
        + model.fc2.weight.nelement()
        + model.fc3.weight.nelement()
    )

In [11]:
print("Global sparsity for unpruned model: {:.2f}%".format(calculate_global_sparsity(original_model)))
print("Global sparsity for locally pruned model: {:.2f}%".format(calculate_global_sparsity(single_pruned_model)))
print("Global sparsity for globally pruned model: {:.2f}%".format(calculate_global_sparsity(globally_pruned_model)))

Global sparsity for unpruned model: 0.00%
Global sparsity for locally pruned model: 0.02%
Global sparsity for globally pruned model: 83.19%
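
These figures follow from compounding: each of the 5 iterations prunes 30% of the weights that remain, so the surviving fraction is 0.7^5. The sketch below reproduces the reported percentages from the layer shapes defined in `Net`. It ignores the rounding PyTorch applies to the number of pruned entries at each step, which is why it is an approximation rather than an exact count.

```python
amount = 0.3
num_iterations = 5

# Fraction of targeted weights that end up pruned after all iterations
compounded = 1 - (1 - amount) ** num_iterations

# Weight counts per layer, taken from the shapes in Net
weights = {
    "conv1": 6 * 3 * 5 * 5,
    "conv2": 16 * 6 * 5 * 5,
    "fc1": 18496 * 120,
    "fc2": 120 * 84,
    "fc3": 84 * 3,
}
total = sum(weights.values())

# Global pruning targets every layer; local pruning targets conv1 only
global_pct = 100 * compounded
local_pct = 100 * compounded * weights["conv1"] / total

print(f"global: {global_pct:.2f}%, local: {local_pct:.2f}%")
# global: 83.19%, local: 0.02%
```

The locally pruned model stays close to 0% global sparsity because `conv1` holds only 450 of the roughly 2.2 million weights, most of which sit in `fc1`.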


### References

- [Pruning for Neural Networks by Lei Mao](https://leimao.github.io/article/Neural-Networks-Pruning/)
- [PyTorch Pruning by Lei Mao](https://leimao.github.io/blog/PyTorch-Pruning/)
- [Pruning Tutorial by PyTorch](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html)