# Chapter-6 Model Pruning in PyTorch and Exporting to ONNX

#### In this notebook, we prune a ResNet50 model in PyTorch. We explore different pruning techniques, measure their effect on the model's accuracy, and then export the pruned model to ONNX.

### Step-1: Train a ResNet50 model on the Cats vs. Dogs dataset

We will use the Cats vs. Dogs dataset from Hugging Face.
Link: https://huggingface.co/datasets/microsoft/cats_vs_dogs

<div style="text-align: center;">
    <h3>Architecture of ResNet50 model</h3>
    <img src="https://www.researchgate.net/profile/Master-Prince/publication/350421671/figure/fig1/AS:1005790324346881@1616810508674/An-illustration-of-ResNet-50-layers-architecture.png" width="400" height="300">
</div>



In [None]:
# Install dependencies
!pip install onnx==1.18.0 onnxruntime==1.22.0 netron==8.4.3
!pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
    --index-url https://download.pytorch.org/whl/cu124
!pip install datasets==4.0.0

In [1]:
import os
import copy
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import torch.optim as optim
from torch.nn.utils import prune
import torchvision.models as models
import torchvision.transforms as T
from datasets import load_dataset
from torchvision import transforms
from collections import defaultdict
from torch.utils.data import DataLoader, random_split

torch.manual_seed(42)

<torch._C.Generator at 0x7fca940d44b0>

In [2]:
# Setup Dataset object for training

class CatDogDataset(torch.utils.data.Dataset):
    # Load dataset from Hugging Face
    dataset = load_dataset("microsoft/cats_vs_dogs")
    
    # Split dataset into train and test sets
    train_size = int(0.8 * len(dataset["train"]))
    test_size = len(dataset["train"]) - train_size
    train_dataset, test_dataset = random_split(dataset["train"], [train_size, test_size])

    def __init__(self, is_train=True):
        if is_train:
            self.ds = self.__class__.train_dataset
            self.transform_fn = self.transforms()
        else:
            self.ds = self.__class__.test_dataset
            self.transform_fn = self.transforms()

    def transforms(self):
        transform = transforms.Compose(
            [
                T.Resize((224, 224)),
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Convert grayscale to RGB
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),      # ImageNet normalization
            ]
        )
        return transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        example = self.ds[idx]
        img = example["image"]
        example.pop("image")
        example["images"] = self.transform_fn(img)
        example["labels"] = torch.tensor(example["labels"], dtype=torch.long)
        return example

In [3]:
# Setup ResNet50 classifier model

class ResNetClassifier(nn.Module):
    def __init__(self, ckpt_path=None):
        super(ResNetClassifier, self).__init__()

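        # Load ResNet50 with ImageNet weights. The `weights` argument replaces the
        # deprecated `pretrained=True`, which maps to the same IMAGENET1K_V1 weights.
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)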

        # Update the last FC layer for 2 class classification problem
        self.model.fc = nn.Linear(self.model.fc.in_features, 2)
        self.loss_fn = nn.CrossEntropyLoss()

        # Load the checkpoint if it is provided.
        if ckpt_path is not None:
            if not os.path.isfile(ckpt_path):
                raise FileNotFoundError(f"File not present at: {ckpt_path}.")
            self.model.load_state_dict(torch.load(ckpt_path))

    def forward(self, images, labels=None):
        logits = self.model(images)
        loss = self.loss_fn(logits, labels) if labels is not None else None
        return (loss, logits) if loss is not None else logits

In [4]:
# Setup the training class

class Trainer:
    def __init__(self, model, train_dataset, test_dataset, batch_size, num_workers, 
                num_epochs, lr, device, model_name):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_epochs = num_epochs
        self.lr = lr
        self.device = device
        self.model_name = model_name

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, num_workers=num_workers
        )
        self.model = model
        self.model = self.model.to(self.device)

    def fit(self):
        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        # Training loop
        for epoch in range(self.num_epochs):
            self.model.train()
            total_loss = 0
            correct = 0
            total = 0
            for batch in tqdm(self.train_loader):
                images, labels = batch["images"].to(self.device), batch["labels"].to(self.device)
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)
            print(f"Epoch {epoch+1}/{self.num_epochs}, Loss: {total_loss/len(self.train_loader)}, Accuracy: {correct/total*100:.2f}%")

            # Test loop
            self.model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for batch in tqdm(self.test_loader):
                    images, labels = batch["images"].to(self.device), batch["labels"].to(self.device)
                    outputs = self.model(images)
                    _, predicted = outputs.max(1)
                    correct += predicted.eq(labels).sum().item()
                    total += labels.size(0)
            print(f"Test Accuracy: {correct/total*100:.2f}%")

        # Save Model
        torch.save(self.model.state_dict(), f"{self.model_name.replace('.', '_').replace('/', '_')}.pth")

    def test(self):
        # Test loop
        self.model.to(self.device)
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(self.test_loader):
                images, labels = batch["images"].to(self.device), batch["labels"].to(
                    self.device
                )
                outputs = self.model(images)
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)
        acc = correct / total * 100
        print(f"Test Accuracy: {acc:.2f}%")
        return acc

In [5]:
# Define Hyperparameters

BATCH_SIZE = 32
LR = 1e-3
EPOCHS = 3
NUM_WORKERS = 4  # Adjust based on CPU cores
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "resnet50_cats_dogs"

In [6]:
# Setup dataloader
train_dataset = CatDogDataset(is_train=True)
test_dataset = CatDogDataset(is_train=False)

In [7]:
# Train model on Cats v/s Dogs dataset.
base_model = ResNetClassifier(None).to(DEVICE)

base_trainer = Trainer(
    base_model, train_dataset, test_dataset, BATCH_SIZE, NUM_WORKERS, 
    EPOCHS, LR, DEVICE, MODEL_NAME
)
base_trainer.fit()
base_acc = base_trainer.test()
print(f"Accuracy before pruning: {base_acc:.2f}%")



  0%|          | 0/586 [00:00<?, ?it/s]

Epoch 1/3, Loss: 0.23909997593598764, Accuracy: 89.91%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 90.24%


  0%|          | 0/586 [00:00<?, ?it/s]

Epoch 2/3, Loss: 0.15712740046127918, Accuracy: 93.75%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 91.03%


  0%|          | 0/586 [00:00<?, ?it/s]


Epoch 3/3, Loss: 0.1338720278788212, Accuracy: 94.49%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 95.39%


  0%|          | 0/147 [00:00<?, ?it/s]


Test Accuracy: 95.39%
Accuracy before pruning: 95.39%


### Step-2: Define some helper functions

In [8]:
# Compute overall model sparsity for unstructured pruning
def compute_global_sparsity(model):
    """Return the percentage of zero-valued entries in the model's weight tensors."""
    total_params = 0
    zero_params = 0

    for name, module in model.named_modules():
        if isinstance(module, nn.CrossEntropyLoss):
            continue
        if hasattr(module, "weight"):
            total_params += module.weight.nelement()  # Total elements
            zero_params += torch.sum(module.weight == 0).item()  # Zero elements
        # Because of the `elif`, a bias is counted only for modules that have no
        # `weight` attribute, so the figures reported here cover weights only.
        elif hasattr(module, "bias") and module.bias is not None:
            total_params += module.bias.nelement()  # Total elements
            zero_params += torch.sum(module.bias == 0).item()  # Zero elements

    sparsity = zero_params / total_params * 100
    return sparsity

# Compute the fraction of parameters that remain after structured pruning,
# relative to the base model
def compute_structured_sparsity(base_model, pruned_model):
    def compute_params(model):
        model_params = 0
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.CrossEntropyLoss):
                continue
            if hasattr(module, "weight"):
                model_params += module.weight.nelement()  # Total elements
            elif hasattr(module, "bias") and module.bias is not None:
                model_params += module.bias.nelement()  # Total elements
        return model_params

    base_model_params = compute_params(base_model)
    pruned_model_params = compute_params(pruned_model)

    # Ratio of remaining parameters (1.0 means nothing was pruned)
    remaining_ratio = pruned_model_params / base_model_params
    return remaining_ratio, base_model_params, pruned_model_params
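
To make the sparsity figure concrete, here is a small sketch on a hypothetical two-layer model (the `toy` network and its hand-set weights are illustrative assumptions, not part of the notebook). It counts zeros across `weight` tensors only, which mirrors what the helpers above effectively measure.

```python
import torch
import torch.nn as nn

# Hypothetical toy model used only for illustration
toy = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 1))
with torch.no_grad():
    toy[0].weight.zero_()      # 8 weights, all zero
    toy[1].weight.fill_(1.0)   # 2 weights, none zero

total = 0
zeros = 0
for module in toy.modules():
    weight = getattr(module, "weight", None)  # containers have no weight
    if weight is not None:
        total += weight.nelement()
        zeros += int((weight == 0).sum())

weight_sparsity = zeros / total * 100
print(f"{weight_sparsity:.1f}%")  # 80.0%
```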

In [9]:
# Helper utility to export the model

def export_to_onnx(model, output_path):
    # Dummy input tensor for ONNX export
    dummy_input = torch.randn(1, 3, 224, 224)
    # Export on CPU and in eval mode so BatchNorm uses its running statistics
    model.to("cpu")
    model.eval()

    # Export the pruned model to ONNX
    torch.onnx.export(
        model, dummy_input, output_path,
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        opset_version=11
    )

    print(f"Model has been exported to {output_path}")

In [10]:
# Export original unpruned model
export_to_onnx(base_model, "./resnet50_cats_vs_dogs_unpruned.onnx")

Model has been exported to ./resnet50_cats_vs_dogs_unpruned.onnx


### Step-3: Apply Unstructured Pruning to the trained ResNet50 model

In [11]:
# We will prune away 40 % of the weights.
unstructured_prune_ratio = 0.4

In [12]:
# We will apply unstructured pruning to all the convolution layer weights.

def apply_unstructured_pruning(model, prune_ratio=0.2):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=prune_ratio)
    
    return model

def update_model_post_pruning(model):
    # To remove pruning reparameterization and store permanently
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.remove(module, 'weight')
    
    return model
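
The two helpers above rely on how `torch.nn.utils.prune` reparameterizes a module. A minimal sketch on a hypothetical 10-weight linear layer shows the mechanics: pruning adds a `weight_orig` parameter and a `weight_mask` buffer, and `prune.remove` folds the mask back into a plain `weight` parameter.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

layer = nn.Linear(5, 2)  # hypothetical toy layer with 10 weights
prune.l1_unstructured(layer, name="weight", amount=0.4)

# While the reparameterization is active, the original values live in
# `weight_orig` and the binary mask in the `weight_mask` buffer.
param_names = sorted(n for n, _ in layer.named_parameters())
buffer_names = sorted(n for n, _ in layer.named_buffers())
zeros_with_mask = int((layer.weight == 0).sum())
print(param_names, buffer_names, zeros_with_mask)

# Make the pruning permanent: `weight` becomes a regular parameter again
prune.remove(layer, "weight")
param_names_after = sorted(n for n, _ in layer.named_parameters())
zeros_after_remove = int((layer.weight == 0).sum())
print(param_names_after, zeros_after_remove)
```

With `amount=0.4` on 10 weights, the 4 entries with the smallest absolute value are zeroed, and they stay zero after `prune.remove`.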

In [13]:
# Pruned model
unstructured_pruned_model = copy.deepcopy(base_model)
unstructured_pruned_model = apply_unstructured_pruning(unstructured_pruned_model, unstructured_prune_ratio)
unstructured_pruned_model_name = "resnet50_cats_dogs_unstructured_prune"

unstructured_prune_trainer = Trainer(
    unstructured_pruned_model, train_dataset, test_dataset, BATCH_SIZE, NUM_WORKERS, 
    EPOCHS, LR, DEVICE, unstructured_pruned_model_name
)
unstructured_pruned_acc = unstructured_prune_trainer.test()
model_sparsity = compute_global_sparsity(unstructured_pruned_model)

print(f"Unstructured pruning ratio: {unstructured_prune_ratio:.2f}")
print(f"Model's overall sparsity: {model_sparsity:.2f} %")
print(f"Accuracy after applying unstructured pruning: {unstructured_pruned_acc:.2f}%")

  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 95.11%
Unstructured pruning ratio: 0.40
Model's overall sparsity: 39.95 %
Accuracy after applying unstructured pruning: 95.11%


#### The pruned model has 40% of its convolution weights set to zero (39.95% overall sparsity), yet accuracy is largely unchanged: it drops from 95.39% to 95.11%, a decrease of only 0.28 percentage points. This suggests that ResNet50 is heavily overparameterized for a relatively simple task such as Cats vs. Dogs classification. With such a small loss, fine-tuning after pruning is not necessary here. For more complex use cases, however, fine-tuning may be required to recover a larger drop in performance.

In [14]:
# Check the weights of one layer of the original model.

print("Weights before pruning")
print(base_model.model.layer1[0].conv1.weight.data[:3, :3, :, :])
print("\nShape of the weight before pruning")
print(base_model.model.layer1[0].conv1.weight.data.shape)

Weights before pruning
tensor([[[[-0.0064]],

         [[ 0.0312]],

         [[-0.0454]]],


        [[[ 0.1461]],

         [[-0.0112]],

         [[ 0.0454]]],


        [[[-0.0481]],

         [[-0.0449]],

         [[-0.0237]]]])

Shape of the weight before pruning
torch.Size([64, 64, 1, 1])


In [15]:
# Check the weights of the same layer in the pruned model.

print("Weight mask for unstructured pruning")
print(unstructured_pruned_model.model.layer1[0].conv1.weight_mask.data[:3, :3, :, :])
print("\nWeights after unstructured pruning")
print(unstructured_pruned_model.model.layer1[0].conv1.weight.data[:3, :3, :, :])
print("\nShape of the weight after pruning")
print(unstructured_pruned_model.model.layer1[0].conv1.weight.data.shape)

Weight mask for unstructured pruning
tensor([[[[0.]],

         [[1.]],

         [[1.]]],


        [[[1.]],

         [[0.]],

         [[1.]]],


        [[[1.]],

         [[1.]],

         [[1.]]]], device='cuda:0')

Weights after unstructured pruning
tensor([[[[-0.0000]],

         [[ 0.0312]],

         [[-0.0454]]],


        [[[ 0.1461]],

         [[-0.0000]],

         [[ 0.0454]]],


        [[[-0.0481]],

         [[-0.0449]],

         [[-0.0237]]]], device='cuda:0')

Shape of the weight after pruning
torch.Size([64, 64, 1, 1])


In [16]:
# Remove the weight masks and make the pruned weights permanent.
unstructured_pruned_model = update_model_post_pruning(unstructured_pruned_model)

In [17]:
# Export unstructured pruned model
export_to_onnx(unstructured_pruned_model, "./resnet50_cats_vs_dogs_unstructured_pruned.onnx")

Model has been exported to ./resnet50_cats_vs_dogs_unstructured_pruned.onnx


### Step-4: Apply Structured Pruning to the trained ResNet50 model

In [18]:
# Defining some helper functions

def update_remaining_channels(name, module, channel_mask, in_channel=True):
    mask = module.weight_mask
    channel = "input" if in_channel else "output"
    weight_shape = len(mask.shape)
    if in_channel:
        if weight_shape == 4:
            # out_channel, in_channel, k_height, k_width --> Conv layer
            dim = [0, 2, 3]
        else:
            # out_channel, in_channel --> FC layer
            dim = [0]
        remaining_out_channels = torch.sum(mask, dim=dim) > 0
    else:
        if weight_shape == 4:
            # out_channel, in_channel, k_height, k_width --> Conv layer
            dim = [1, 2, 3]
        else:
            # out_channel, in_channel --> FC layer
            dim = [1]
        remaining_out_channels = torch.sum(mask, dim=dim) > 0
    channel_mask[name][channel] = remaining_out_channels


def prune_weights(name, module, prune_ratio, channel_mask, in_channel=True):
    dim = 1 if in_channel else 0
    # ln_structured will identify redundant input or output channels from the weights.
    prune.ln_structured(module, name="weight", amount=prune_ratio, n=1, dim=dim)
    update_remaining_channels(name, module, channel_mask, in_channel)


#### Here we prune only layer4 and the FC layer of the ResNet50 model. The code first applies structured pruning to the output channels of the first conv layer in layer4. Because the number of output channels changes, the next layer is no longer compatible with it, so we cascade the change and prune the next layer's input channels as well. We then prune that layer's output channels to shrink the model further, and repeat this cascade through all subsequent layers.
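
The cascading idea can be illustrated on two hypothetical 1x1 convolutions. In this sketch the keep mask of the first layer's output channels is reused to slice the second layer's input channels, which is one way to keep the two layers aligned; the notebook's helper instead selects input channels with a separate `ln_structured` call. The layer sizes and the `keep` mask are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv_a = nn.Conv2d(3, 4, kernel_size=1, bias=False)
conv_b = nn.Conv2d(4, 2, kernel_size=1, bias=False)

# Hypothetical keep mask for conv_a's output channels (channel 2 is pruned)
keep = torch.tensor([True, True, False, True])
with torch.no_grad():
    conv_a.weight[~keep] = 0.0  # masked, but the layer is still full-size

# Physically smaller layers: slice conv_a's outputs and conv_b's inputs
n_keep = int(keep.sum())
small_a = nn.Conv2d(3, n_keep, kernel_size=1, bias=False)
small_b = nn.Conv2d(n_keep, 2, kernel_size=1, bias=False)
small_a.weight.data = conv_a.weight.data[keep].clone()
small_b.weight.data = conv_b.weight.data[:, keep].clone()

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    full_out = conv_b(torch.relu(conv_a(x)))
    small_out = small_b(torch.relu(small_a(x)))

outputs_match = torch.allclose(full_out, small_out, atol=1e-5)
print(outputs_match)  # True
```

Because the zeroed channel contributes nothing downstream, removing it from both layers leaves the output unchanged while reducing the parameter count.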


In [19]:
# Helper function to apply structured pruning. 

def apply_structured_pruning(base_model, prune_ratio):
    pruned_model = copy.deepcopy(base_model)

    # Apply structured pruning and store pruned channels
    channel_mask = defaultdict(dict)  # Dictionary to store pruned channels per layer
    module_dict = {}

    for name, module in pruned_model.model.named_modules():
        if (
            name == "conv1"
            or name.startswith("layer1")
            or name.startswith("layer2")
            or name.startswith("layer3")
        ):
            continue
        module_dict[name] = module
        if isinstance(module, nn.Conv2d):
            # Update out channel
            prune_weights(name, module, prune_ratio, channel_mask, False)
            if (
                name != "conv1"
                and name != "layer4.0.conv1"
                and name != "layer4.0.downsample.0"
            ):
                # Update in channel
                prune_weights(name, module, prune_ratio, channel_mask, True)

        elif isinstance(module, nn.Linear):
            if name == "fc":
                # Update in channel
                prune_weights(name, module, prune_ratio, channel_mask, True)

    existing_modules = list(pruned_model.model.named_modules())
    for name, module in tqdm(existing_modules):
        if (
            name == "conv1"
            or name == "bn1"
            or name.startswith("layer1")
            or name.startswith("layer2")
            or name.startswith("layer3")
        ):
            continue
        if name in channel_mask and isinstance(module, nn.Conv2d):
            channel_dict = channel_mask[name]
            if "input" in channel_dict:
                new_in_channels = (
                    channel_dict["input"].sum().item()
                )  # Count remaining filters
            else:
                new_in_channels = module.in_channels

            if "output" in channel_dict:
                new_out_channels = (
                    channel_dict["output"].sum().item()
                )  # Count remaining filters
            else:
                new_out_channels = module.out_channels

            new_conv2d = nn.Conv2d(
                in_channels=new_in_channels,
                out_channels=new_out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
                bias=module.bias is not None,
                padding_mode=module.padding_mode,
            )
            # Slice the original weights down to the kept input/output channels
            new_weight = module.weight.data
            if "input" in channel_dict:
                new_weight = new_weight[:, channel_dict["input"], :, :]
            if "output" in channel_dict:
                new_weight = new_weight[channel_dict["output"]]
            new_conv2d.weight.data = new_weight.clone()

            # Always take the bias from the original module, never from the
            # freshly initialized replacement layer
            if module.bias is not None:
                new_bias = module.bias.data
                if "output" in channel_dict:
                    new_bias = new_bias[channel_dict["output"]]
                new_conv2d.bias.data = new_bias.clone()

            if name.startswith("layer"):
                splits = name.split(".")
                layer_idx, sub_layer_idx, conv_name = splits[:3]
                if len(splits) == 4:
                    downsample_idx = splits[3]
                    par_module = getattr(
                        getattr(getattr(pruned_model.model, layer_idx), sub_layer_idx),
                        conv_name,
                    )
                    setattr(par_module, downsample_idx, new_conv2d)
                else:
                    par_module = getattr(
                        getattr(pruned_model.model, layer_idx), sub_layer_idx
                    )
                    setattr(par_module, conv_name, new_conv2d)
            else:
                setattr(pruned_model.model, name, new_conv2d)

        if isinstance(module, nn.BatchNorm2d):
            if "downsample" in name:
                conv_name = name.replace("downsample.1", "downsample.0")
            else:
                conv_name = name.replace("bn", "conv")

            channel_dict = channel_mask[conv_name]
            if "output" in channel_dict:
                new_out_channels = (
                    channel_dict["output"].sum().item()
                )  # Count remaining filters

                new_bn = nn.BatchNorm2d(
                    num_features=new_out_channels,
                    eps=module.eps,
                    momentum=module.momentum,
                    affine=module.affine,
                )

                new_bn.running_mean.data = module.running_mean.data[
                    channel_dict["output"]
                ].clone()
                new_bn.running_var.data = module.running_var.data[
                    channel_dict["output"]
                ].clone()
                if module.affine:
                    new_bn.weight.data = module.weight.data[
                        channel_dict["output"]
                    ].clone()
                    new_bn.bias.data = module.bias.data[channel_dict["output"]].clone()

                if name.startswith("layer"):
                    splits = name.split(".")
                    layer_idx, sub_layer_idx, bn_name = splits[:3]
                    if len(splits) == 4:
                        downsample_idx = splits[3]
                        par_module = getattr(
                            getattr(
                                getattr(pruned_model.model, layer_idx), sub_layer_idx
                            ),
                            bn_name,
                        )
                        setattr(par_module, downsample_idx, new_bn)
                    else:
                        par_module = getattr(
                            getattr(pruned_model.model, layer_idx), sub_layer_idx
                        )
                        setattr(par_module, bn_name, new_bn)
                else:
                    setattr(pruned_model.model, name, new_bn)

        if isinstance(module, nn.Linear):
            channel_dict = channel_mask[name]
            if "input" in channel_dict:
                new_in_channels = (
                    channel_dict["input"].sum().item()
                )  # Count remaining filters
            else:
                new_in_channels = module.in_features

            if "output" in channel_dict:
                new_out_channels = (
                    channel_dict["output"].sum().item()
                )  # Count remaining filters
            else:
                new_out_channels = module.out_features

            new_linear = nn.Linear(
                in_features=new_in_channels,
                out_features=new_out_channels,
                bias=module.bias is not None,
            )

            # Slice the original weights down to the kept input/output features
            new_weight = module.weight.data
            if "input" in channel_dict:
                new_weight = new_weight[:, channel_dict["input"]]
            if "output" in channel_dict:
                new_weight = new_weight[channel_dict["output"]]
            new_linear.weight.data = new_weight.clone()

            # Copy the original bias even when only the inputs were pruned;
            # otherwise the new layer keeps its random initial bias
            if module.bias is not None:
                new_bias = module.bias.data
                if "output" in channel_dict:
                    new_bias = new_bias[channel_dict["output"]]
                new_linear.bias.data = new_bias.clone()

            if name.startswith("layer"):
                splits = name.split(".")
                layer_idx, sub_layer_idx, linear_name = splits
                par_module = getattr(
                    getattr(pruned_model.model, layer_idx), sub_layer_idx
                )
                setattr(par_module, linear_name, new_linear)
            else:
                setattr(pruned_model.model, name, new_linear)

    sparsity, base_param_cnt, prune_param_cnt = compute_structured_sparsity(
        base_model, pruned_model
    )
    return pruned_model, sparsity * 100, base_param_cnt, prune_param_cnt
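
The function above swaps layers by walking the module hierarchy with nested `getattr` calls. As a possible simplification, the sketch below uses `nn.Module.get_submodule` to look up the parent from a dotted name. The helper name `replace_submodule` and the toy model are hypothetical and not part of the notebook.

```python
import torch.nn as nn


def replace_submodule(root, dotted_name, new_module):
    """Swap the module at `dotted_name` (e.g. "layer4.0.downsample.0")."""
    parent_name, _, child_name = dotted_name.rpartition(".")
    # An empty parent name means the child hangs directly off the root
    parent = root.get_submodule(parent_name) if parent_name else root
    setattr(parent, child_name, new_module)


# Hypothetical toy model with one nested block
toy = nn.Sequential()
toy.add_module("block", nn.Sequential(nn.Conv2d(3, 8, 1), nn.BatchNorm2d(8)))

replace_submodule(toy, "block.0", nn.Conv2d(3, 6, 1))
replace_submodule(toy, "block.1", nn.BatchNorm2d(6))
print(toy.block[0].out_channels, toy.block[1].num_features)  # 6 6
```

A single helper like this handles any nesting depth, so the separate branches for three-part and four-part names would no longer be needed.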


#### We use iterative pruning for structured pruning because, as we will see, the accuracy drop is significant and must be recovered through retraining. At each step we prune 10% of the channels in layer4 and the FC layer of the ResNet50 model, then retrain the model for 1 epoch. This process is repeated for 3 iterations.
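
Because each round prunes 10% of what is left, the reductions compound. The sketch below tracks the channel count of a 2048-channel layer (the width of ResNet50's layer4 output) over 3 rounds, assuming the number of channels removed per round is `round(prune_ratio * channels)`; the helper name `channels_per_round` is hypothetical.

```python
def channels_per_round(initial_channels, prune_ratio, rounds):
    """Return the channel count before pruning and after each round."""
    counts = [initial_channels]
    for _ in range(rounds):
        # Assumed rounding rule for a fractional pruning amount
        pruned = round(prune_ratio * counts[-1])
        counts.append(counts[-1] - pruned)
    return counts


counts = channels_per_round(2048, 0.1, 3)
print(counts)  # [2048, 1843, 1659, 1493]
print(f"{counts[-1] / counts[0]:.1%}")  # 72.9%
```

After 3 rounds roughly 0.9^3 = 72.9% of the channels remain, which is why the cumulative effect is larger than a single 10% step suggests.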

In [20]:
pruning_iters = 3
prune_ratio = 0.1

model = base_model
for i in tqdm(range(pruning_iters)):
    print(f"[Step-{i+1}] Pruning the model. Prune ratio={prune_ratio}")
    model = model.to("cpu")
    pruned_model, prune_percentage, base_param_cnt, prune_param_cnt = (
        apply_structured_pruning(model, prune_ratio=prune_ratio)
    )
    print(f"[Step-{i+1}] Base model params: {base_param_cnt:,}")
    print(f"[Step-{i+1}] Pruned model params: {prune_param_cnt:,}")
    print(f"[Step-{i+1}] %age params after pruning: {prune_percentage:.2f} %")

    pruned_model.to(DEVICE)
    dummy_input = torch.randn(1, 3, 224, 224).to(torch.float32).to(DEVICE)
    pruned_model(dummy_input)

    # Measure accuracy
    PRUNED_MODEL_NAME = f"{MODEL_NAME}_pruned_iter_{i+1}"
    prune_trainer = Trainer(
        pruned_model,
        train_dataset,
        test_dataset,
        BATCH_SIZE,
        NUM_WORKERS,
        1,
        LR,
        DEVICE,
        PRUNED_MODEL_NAME,
    )
    prune_acc = prune_trainer.test()
    print(f"[Step-{i+1}] Accuracy after pruning: {prune_acc:.2f}%")

    # Recovery training
    prune_trainer.fit()
    prune_acc = prune_trainer.test()
    print(f"[Step-{i+1}] Accuracy after re-training: {prune_acc:.2f}%")

    model = pruned_model

structured_pruned_model = model

  0%|          | 0/3 [00:00<?, ?it/s]

[Step-1] Pruning the model. Prune ratio=0.1


  0%|          | 0/151 [00:00<?, ?it/s]

[Step-1] Base model params: 23,485,568
[Step-1] Pruned model params: 20,887,302
[Step-1] %age params after pruning: 88.94 %


  0%|          | 0/147 [00:00<?, ?it/s]


Test Accuracy: 79.50%
[Step-1] Accuracy after pruning: 79.50%


  0%|          | 0/586 [00:00<?, ?it/s]


Epoch 1/1, Loss: 0.1366160377675712, Accuracy: 94.61%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 92.18%


  0%|          | 0/147 [00:00<?, ?it/s]


Test Accuracy: 92.18%
[Step-1] Accuracy after re-training: 92.18%
[Step-2] Pruning the model. Prune ratio=0.1


  0%|          | 0/151 [00:00<?, ?it/s]

[Step-2] Base model params: 20,887,302
[Step-2] Pruned model params: 18,756,720
[Step-2] %age params after pruning: 89.80 %


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 49.62%
[Step-2] Accuracy after pruning: 49.62%


  0%|          | 0/586 [00:00<?, ?it/s]

Epoch 1/1, Loss: 0.1307951264622154, Accuracy: 94.81%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 94.77%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 94.77%
[Step-2] Accuracy after re-training: 94.77%
[Step-3] Pruning the model. Prune ratio=0.1


  0%|          | 0/151 [00:00<?, ?it/s]

[Step-3] Base model params: 18,756,720
[Step-3] Pruned model params: 16,990,908
[Step-3] %age params after pruning: 90.59 %


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 49.62%
[Step-3] Accuracy after pruning: 49.62%


  0%|          | 0/586 [00:00<?, ?it/s]

Epoch 1/1, Loss: 0.11634316487768633, Accuracy: 95.30%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 94.40%


  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 94.40%
[Step-3] Accuracy after re-training: 94.40%


In [21]:
structured_pruned_model_name = "resnet50_cats_dogs_structured_prune"

structured_prune_trainer = Trainer(
    structured_pruned_model, train_dataset, test_dataset, BATCH_SIZE, NUM_WORKERS, 
    EPOCHS, LR, DEVICE, structured_pruned_model_name
)
structured_pruned_acc = structured_prune_trainer.test()

model_sparsity, base_model_params, pruned_model_params = compute_structured_sparsity(base_model, structured_pruned_model)

print(f"Base model params count: {base_model_params:,}")
print(f"Structurally Pruned model params count: {pruned_model_params:,}")
print(f"Model's overall sparsity: {model_sparsity*100:.2f} %")
print(f"Accuracy after applying structured pruning: {structured_pruned_acc:.2f}%")

  0%|          | 0/147 [00:00<?, ?it/s]

Test Accuracy: 94.40%
Base model params count: 23,485,568
Structurally Pruned model params count: 16,990,908
Model's overall sparsity: 72.35 %
Accuracy after applying structured pruning: 94.40%


#### After structured pruning, the model's parameter count dropped from 23,485,568 to 16,990,908, a reduction of approximately 28% (the 72.35% printed above as "overall sparsity" is the share of parameters retained). Accuracy decreased only slightly, from 95.39% to 94.40%, a drop of 0.99 percentage points, which is small relative to the 28% reduction in parameters.
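
The percentages quoted above can be re-derived from the logged parameter counts. The sketch below is plain Python arithmetic rather than part of the notebook's pipeline. The helper name `retained_fraction` is made up for this illustration, and the Step-1 count of 20,887,302 is taken from the "Base model params" line logged at the start of Step-2.

```python
# Parameter counts copied from the training log above.
base_params = 23_485_568
params_after_step = [20_887_302, 18_756_720, 16_990_908]  # Steps 1-3


def retained_fraction(before: int, after: int) -> float:
    """Fraction of parameters that survive one pruning step."""
    return after / before


# Per-step retention: compare each count with the one before it.
counts = [base_params] + params_after_step
per_step = [retained_fraction(b, a) for b, a in zip(counts, counts[1:])]

# Overall retention equals final / base, which is also the product
# of the per-step fractions.
overall_retained = retained_fraction(base_params, params_after_step[-1])
overall_reduction = 1.0 - overall_retained

print([f"{p * 100:.2f}%" for p in per_step])
print(f"Retained: {overall_retained * 100:.2f}%")
print(f"Removed:  {overall_reduction * 100:.2f}%")
```

Because each step prunes the already-pruned model, the retained fractions multiply. This is why three rounds at a prune ratio of 0.1 remove a little under 28% of the parameters rather than 30%.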

In [22]:
# Inspect the conv1 weights in the first block of layer4 of the original model.

print("Weights before pruning")
print(base_model.model.layer4[0].conv1.weight.data[:3, :3, :, :])
print("\nShape of the weight before pruning")
print(base_model.model.layer4[0].conv1.weight.data.shape)

Weights before pruning
tensor([[[[-2.4183e-03]],

         [[ 1.6129e-02]],

         [[ 1.1725e-02]]],


        [[[ 4.3987e-02]],

         [[ 3.0392e-02]],

         [[ 1.8671e-02]]],


        [[[-3.4788e-02]],

         [[ 3.7122e-03]],

         [[ 1.9509e-05]]]])

Shape of the weight before pruning
torch.Size([512, 1024, 1, 1])


In [23]:
# Inspect the conv1 and conv2 weights in the first block of layer4 of the pruned model.

print("Weights after structured pruning")
print(structured_pruned_model.model.layer4[0].conv1.weight.data[:3, :3, :, :])
print("\nShape of the weight after pruning")
print(structured_pruned_model.model.layer4[0].conv1.weight.data.shape)

print("\nShape of the weight after layer4 conv in pruned model")
print(structured_pruned_model.model.layer4[0].conv2.weight.data.shape)

Weights after structured pruning
tensor([[[[-0.0286]],

         [[-0.0179]],

         [[ 0.0131]]],


        [[[-0.0385]],

         [[ 0.0290]],

         [[-0.0172]]],


        [[[ 0.0439]],

         [[-0.0122]],

         [[-0.0087]]]], device='cuda:0')

Shape of the weight after pruning
torch.Size([373, 1024, 1, 1])

Shape of the weight after layer4 conv in pruned model
torch.Size([373, 373, 3, 3])


#### The weight shape follows the format [out_channels, in_channels, kernel_height, kernel_width]. After pruning, the number of out_channels in `conv1` decreased from 512 to 373. Because these outputs feed the next convolution, the in_channels of `conv2` were reduced to 373 as well, as its weight shape [373, 373, 3, 3] shows.
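
To see what the shape change means in terms of weights, the counts can be worked out from the shapes alone. The sketch below is plain Python with shapes copied from the outputs above. `conv_weight_count` is a hypothetical helper, and the unpruned `conv2` shape of [512, 512, 3, 3] is assumed from the standard torchvision ResNet50 definition rather than printed in this notebook.

```python
def conv_weight_count(shape):
    """Number of weights in a conv kernel of shape
    [out_channels, in_channels, kernel_height, kernel_width]."""
    out_ch, in_ch, k_h, k_w = shape
    return out_ch * in_ch * k_h * k_w


# layer4[0].conv1 shapes, as printed above.
conv1_before = (512, 1024, 1, 1)
conv1_after = (373, 1024, 1, 1)

# layer4[0].conv2: the "before" shape is an assumption based on the
# standard ResNet50 definition; the "after" shape is printed above.
conv2_before = (512, 512, 3, 3)
conv2_after = (373, 373, 3, 3)

# conv1's output channels feed conv2, so the two counts must agree.
assert conv1_after[0] == conv2_after[1]

for name, before, after in [
    ("conv1", conv1_before, conv1_after),
    ("conv2", conv2_before, conv2_after),
]:
    n_before = conv_weight_count(before)
    n_after = conv_weight_count(after)
    print(f"{name}: {n_before:,} -> {n_after:,} weights "
          f"({n_after / n_before:.1%} kept)")
```

Under these assumptions `conv2` keeps a smaller share of its weights than `conv1`, because it shrinks along both the input and the output channel dimension.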

In [24]:
# Export structured pruned model
export_to_onnx(structured_pruned_model, "./resnet50_cats_vs_dogs_structured_pruned.onnx")

Model has been exported to ./resnet50_cats_vs_dogs_structured_pruned.onnx


<div style="text-align: center;">
    <h3>Architecture of ResNet50 model</h3>
</div>


<table>
  <tr>
    <td style="padding-right: 30px;">
        <img src="https://www.researchgate.net/profile/Master-Prince/publication/350421671/figure/fig1/AS:1005790324346881@1616810508674/An-illustration-of-ResNet-50-layers-architecture.png" width="400" height="300"><br><center><b>Before Pruning</b></center>
    </td>
    <td style="left: 30px;">
        <img src="ResNet50_Architecture_Pruned.png" width="400"><br><center><b>After Pruning</b></center>
    </td>
  </tr>
</table>

**Note:** The images above compare the model architecture before and after pruning. The last group of layers, shown in purple, changes once pruning is applied. The same change can be seen by opening the exported ONNX model in a viewer such as Netron.
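
The reduced channel counts can also be checked programmatically in the exported file. The following is a minimal sketch, assuming the `onnx` package installed at the top of the notebook and the output path used in the export cell. `conv_kernel_shapes` is a hypothetical helper, not part of the notebook. Initializer names depend on the exporter and may not match the PyTorch parameter names, so the sketch selects tensors by rank instead of by name.

```python
import os


def conv_kernel_shapes(onnx_path):
    """Return {initializer_name: shape} for every 4-D tensor stored in
    an ONNX file; in a CNN these are the convolution kernels.
    Returns an empty dict if the file does not exist."""
    if not os.path.exists(onnx_path):
        return {}
    import onnx  # imported lazily so the missing-file path needs no onnx

    model = onnx.load(onnx_path)
    return {
        init.name: tuple(init.dims)
        for init in model.graph.initializer
        if len(init.dims) == 4
    }


# Path taken from the export cell above.
shapes = conv_kernel_shapes("./resnet50_cats_vs_dogs_structured_pruned.onnx")
for name, shape in shapes.items():
    print(name, shape)
```

If the export matches the PyTorch model, the listed kernels should include shapes with 373 output channels in place of the original 512.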