# Pruning Example
This notebook demonstrates how to use the pruning methods from this tool to compress a model. 

The example uses the MNIST dataset and a simple CNN model. The model is trained and then pruned using the methods in this tool. The pruned model is then evaluated on the test set to see how well it performs.

## Setup

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

sys.path.append('../')
sys.path.append('../src/')

import src.general as general
import src.compression.pruning as prune
import src.metrics as metrics
import src.plot as plot
import src.evaluation as eval
import src.dataset_models as data
import mnist


  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Get device
device = general.get_device()

# Load the dataset
dataset = data.supported_datasets["MNIST"]

#### Load an existing model

In [7]:
model_state = "../models/mnist.pt"
model_class = "models.mnist"


# Load the model
model = torch.load(model_state, map_location=torch.device(device))

#### Or create a new model

In [5]:
model = mnist.MnistModel().to(device)

In [6]:
epochs = 3

for epoch in range(1, epochs + 1):
    general.train(model, dataset)
    general.test(model, dataset)

Train: 100%|██████████| 938/938 [00:22<00:00, 40.96it/s]


Train loss: 0.6501
Train score: 81.4416


Test: 100%|██████████| 157/157 [00:01<00:00, 122.73it/s]


Test loss: 0.1628
Test score: 95.2428


Train:  96%|█████████▋| 905/938 [00:21<00:00, 41.33it/s]

In [None]:
save_model = True
save_model_path = "../models/mnist.pt"

if save_model:
    torch.save(model, save_model_path)

## Pruning
Pruning is a process of reducing the size of a machine learning model by removing unimportant weights and neurons. Pruning can be used to reduce the number of parameters in a model, thereby reducing the memory footprint and the computational complexity of the model. 

Pruning is typically done in two ways: structured pruning, which involves selectively removing a larger part of the network such as a layer or a channel, and unstructured pruning, which involves removing individual weights or neurons.

### Unstructured Pruning
Here we prune individual weights or neurons.

In [8]:
# Evaluate model performance before pruning
original_results = eval.test_and_get_results(model, dataset)
plot.print_results(**original_results)

Test: 100%|██████████| 157/157 [00:01<00:00, 126.91it/s]

Test loss: 0.0771
Test score: 97.6712
Loss: 0.077142
Score: 97.671178
Time per data point: 0.4931 ms
Model Size: 1.65 MB
Number of parameters: 431080
Number of FLOPs: 2307720
Number of MACs: 2307728





In [9]:
# Define the parameters to prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)

prune.magnitude_pruning_global_unstructured(model, 0.9)

pruned_results = eval.get_results(model, dataset)

plot.print_before_after_results(original_results, pruned_results)

Test: 100%|██████████| 157/157 [00:01<00:00, 124.01it/s]

Test loss: 0.8768
Test score: 87.1119
Could not calculate FLOPS
Loss: 0.077142 -> 0.876822 (1036.63%)
Score: 97.671178 -> 87.111863 (-10.81%)
Time per data point: 0.4931 ms -> 0.5045 ms (2.32%)
Model Size: 1.65 MB -> 1.65 MB (-0.00%)
Number of parameters: 431080 -> 431080 (-0.00%)
Number of FLOPs: 2307720 -> -1 (-100.00%)
Number of MACs: 2307728 -> 2307728 (-0.00%)





In [10]:
# Print number of parameters in model
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')


# Print number of parameters pruned
total_pruned_params = sum(p.numel() for p in model.parameters() if hasattr(p, 'mask'))
print(f'{total_pruned_params:,} parameters pruned ({100 * total_pruned_params / total_params:.2f}% pruned)')

# Print number of parameters remaining
total_unpruned_params = sum(p.numel() for p in model.parameters() if not hasattr(p, 'mask'))
print(f'{total_unpruned_params:,} parameters unpruned ({100 * total_unpruned_params / total_params:.2f}% unpruned)')


431,080 total parameters.
0 parameters pruned (0.00% pruned)
431,080 parameters unpruned (100.00% unpruned)


## Structured Pruning

In [11]:
import torch_pruning as tp
from torchvision.models import resnet18
from torchvision.models.resnet import ResNet18_Weights

# 0. prepare your model and example inputs
model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
example_inputs = torch.randn(1,3,224,224)
# input_batch = next(iter(train_loader))
# example_input = input_batch[0][0]


# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = pruning_idxs=[2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )

# 3. prune all grouped layer that is coupled with model.conv1
if DG.check_pruning_group(pruning_group):
    pruning_group.exec()

In [12]:
print("After pruning:")
metrics = {
    "model": model,
    "batch_size": batch_size,
    "example_input": example_input,
}
plot.print_metrics(**metrics)

After pruning:


NameError: name 'batch_size' is not defined

In [20]:
list(model.children())

[Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(inplace=True),
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
 Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
   (1): BasicBlock(
     (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), pad

In [None]:
print(pruning_group)



--------------------------------
          Pruning Group
--------------------------------
[0] [DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #Pruned=3
[1] [DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=3
[2] [DEP] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=3
[3] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0), #Pruned=3
[4] [DEP] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_out_channels on _

In [None]:
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 0. importance criterion for parameter selections
imp = tp.importance.MagnitudeImportance(p=2, group_reduction='mean')

# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

        
# 2. Pruner initialization
iterative_steps = 5 # You can prune your model to the target sparsity iteratively.
pruner = tp.pruner.MagnitudePruner(
    model, 
    example_inputs, 
    global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
    importance=imp, # importance criterion for parameter selection
    iterative_steps=iterative_steps, # the number of iterations to achieve target sparsity
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)
# print(model)
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    # 3. the pruner.step will remove some channels from the model with least importance
    pruner.step()
    
    # 4. Do whatever you like here, such as fintuning
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # print(model)
    print(model(example_inputs).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )
    metrics = {
        "model": model,
        "example_input": example_input,
    }
    print(eval.get_size(model), "MB")
    print(len(list(model.children())), "children")
    # finetune your model here
    # finetune(model)
    # ...
print(list(model.modules()))

torch.Size([1, 1000])
  Iter 1/5, Params: 11.69 M => 9.48 M
  Iter 1/5, MACs: 1.82 G => 1.47 G
36.24 MB
10 children
torch.Size([1, 1000])
  Iter 2/5, Params: 11.69 M => 7.53 M
  Iter 2/5, MACs: 1.82 G => 1.18 G
28.8 MB
10 children
torch.Size([1, 1000])
  Iter 3/5, Params: 11.69 M => 5.82 M
  Iter 3/5, MACs: 1.82 G => 0.91 G
22.26 MB
10 children
torch.Size([1, 1000])
  Iter 4/5, Params: 11.69 M => 4.32 M
  Iter 4/5, MACs: 1.82 G => 0.68 G
16.53 MB
10 children
torch.Size([1, 1000])
  Iter 5/5, Params: 11.69 M => 3.06 M
  Iter 5/5, MACs: 1.82 G => 0.49 G
11.71 MB
10 children
[ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=

### Attribution Scores
Calculating the importance of different modules

In [None]:
from torchpruner.attributions import RandomAttributionMetric  # or any of the methods above

attr = RandomAttributionMetric(model, test_loader, criterion, device)
for module in model.children():
    if len(list(module.children())) == 0:  # leaf module
        scores = attr.run(module)
        print(module)
        print (len(scores))      

Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
20
Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
50
Linear(in_features=800, out_features=500, bias=True)
500
Linear(in_features=500, out_features=10, bias=True)
10


### Any Pruning

In [None]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp

from torchvision.models.resnet import ResNet18_Weights

# 0. prepare your model and example inputs
model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()

example_input = torch.randn(1,3,224,224)

# eval.get_model_flops(model, example_input)
print(eval.get_size(model), "MB")

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=example_input)

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )

# 3. prune all grouped layer that is coupled with model.conv1 (included).
if DG.check_pruning_group(pruning_group): # avoid full pruning, i.e., channels=0.
    pruning_group.exec()


print(eval.get_size(model), "MB")


# 4. save & load the pruned model 
torch.save(model, 'model.pth') # save the model object
model_loaded = torch.load('model.pth') # no load_state_dict

44.66 MB
44.62 MB


In [None]:
for x in pruning_group:
    print(x)

([DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), [2, 6, 9])
([DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), [2, 6, 9])
([DEP] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), [2, 6, 9])
([DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0), [2, 6, 9])
([DEP] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), [2, 6, 9])
([DEP] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndic

In [15]:
import re
import inspect
from typing import Tuple, Union

def get_first_last_layers(model: torch.nn.Module) -> Tuple[torch.nn.Module, torch.nn.Module]:
    forward_source = inspect.getsource(model.forward)
    lines = forward_source.split('\n')[1:]  # Exclude the method definition line

    print(lines)
    
    layer_names = []
    for line in lines:
        match = re.search(r'self\.(\w+)', line)
        if match:
            layer_names.append(match.group(1))
    
    if not layer_names:
        raise ValueError('No layers found in the forward method')
    
    first_layer = getattr(model, layer_names[0])
    last_layer = getattr(model, layer_names[-1])
    
    return first_layer, last_layer

# Example usage
first_layer, last_layer = get_first_last_layers(model)
print('First layer called:', first_layer)
print('Last layer called:', last_layer)


['        return self._forward_impl(x)', '']
First layer called: <bound method ResNet._forward_impl of ResNet(
  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(