# Efficient Machine Learning - Fine-grained and channel pruning
This notebook concerns the concept of neural network pruning. The concepts of fine-grained and channel pruning are implemented and tested. The performance improvements and differences and tradeoffs between these pruning approaches are compared

## Model and setup
We will conduct the experiments on VGG16. The model is quite outdated by today standards, however, it is easily dissectable and there are verious pretrained variants available. This makes it suitable for the purpose of this notebook.

In [1]:
%pip install torchprofile

Collecting torchprofile
  Downloading torchprofile-0.0.4-py3-none-any.whl.metadata (303 bytes)
Downloading torchprofile-0.0.4-py3-none-any.whl (7.7 kB)
Installing collected packages: torchprofile
Successfully installed torchprofile-0.0.4
Note: you may need to restart the kernel to use updated packages.


In [2]:
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
import torch
import torch.nn as nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from torchvision.models import vgg16
from tqdm.auto import tqdm

# Ensure CUDA support
assert torch.cuda.is_available(), "The runtime has no CUDA support"

In [3]:
# setting seeds for reproducability
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7c0a0ec63410>

The model architecture is the same as vgg11_bn in torchvision - this enables us to use the weights pretrained on imagenet_1k. This class definition is used to add explicit layer names to the model.

In [4]:
class VGG(nn.Module):
#   ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
  ARCH =  [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]

  def __init__(self) -> None:
    super().__init__()

    layers = []
    counts = defaultdict(int)

    def add(name: str, layer: nn.Module) -> None:
      layers.append((f"{name}{counts[name]}", layer))
      counts[name] += 1

    in_channels = 3
    for x in self.ARCH:
      if x != 'M':
        # conv-bn-relu
        add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
        add("bn", nn.BatchNorm2d(x))
        add("relu", nn.ReLU(True))
        in_channels = x
      else:
        # maxpool
        add("pool", nn.MaxPool2d(2))

    self.backbone = nn.Sequential(OrderedDict(layers))
    self.classifier = nn.Linear(512, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
    x = self.backbone(x)

    # avgpool: [N, 512, 2, 2] => [N, 512]
    x = x.mean([2, 3])

    # classifier: [N, 512] => [N, 10]
    x = self.classifier(x)
    return x

In [5]:
model = VGG().cuda()

In [6]:
model

VGG(
  (backbone): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, a

In [7]:
for name, param in model.named_parameters():
    print(name, param.size())

backbone.conv0.weight torch.Size([64, 3, 3, 3])
backbone.bn0.weight torch.Size([64])
backbone.bn0.bias torch.Size([64])
backbone.conv1.weight torch.Size([128, 64, 3, 3])
backbone.bn1.weight torch.Size([128])
backbone.bn1.bias torch.Size([128])
backbone.conv2.weight torch.Size([256, 128, 3, 3])
backbone.bn2.weight torch.Size([256])
backbone.bn2.bias torch.Size([256])
backbone.conv3.weight torch.Size([256, 256, 3, 3])
backbone.bn3.weight torch.Size([256])
backbone.bn3.bias torch.Size([256])
backbone.conv4.weight torch.Size([512, 256, 3, 3])
backbone.bn4.weight torch.Size([512])
backbone.bn4.bias torch.Size([512])
backbone.conv5.weight torch.Size([512, 512, 3, 3])
backbone.bn5.weight torch.Size([512])
backbone.bn5.bias torch.Size([512])
backbone.conv6.weight torch.Size([512, 512, 3, 3])
backbone.bn6.weight torch.Size([512])
backbone.bn6.bias torch.Size([512])
backbone.conv7.weight torch.Size([512, 512, 3, 3])
backbone.bn7.weight torch.Size([512])
backbone.bn7.bias torch.Size([512])
classi

Load the pretrained model and transfer the weights.

In [8]:
from torchvision.models import vgg11_bn
torch_vgg = vgg11_bn(pretrained=True)
state_dict = torch_vgg.state_dict()
model.load_state_dict(state_dict)

Downloading: "https://download.pytorch.org/models/vgg11_bn-6002323d.pth" to /root/.cache/torch/hub/checkpoints/vgg11_bn-6002323d.pth
100%|██████████| 507M/507M [00:03<00:00, 159MB/s]  


RuntimeError: Error(s) in loading state_dict for VGG:
	Missing key(s) in state_dict: "backbone.conv0.weight", "backbone.bn0.weight", "backbone.bn0.bias", "backbone.bn0.running_mean", "backbone.bn0.running_var", "backbone.conv1.weight", "backbone.bn1.weight", "backbone.bn1.bias", "backbone.bn1.running_mean", "backbone.bn1.running_var", "backbone.conv2.weight", "backbone.bn2.weight", "backbone.bn2.bias", "backbone.bn2.running_mean", "backbone.bn2.running_var", "backbone.conv3.weight", "backbone.bn3.weight", "backbone.bn3.bias", "backbone.bn3.running_mean", "backbone.bn3.running_var", "backbone.conv4.weight", "backbone.bn4.weight", "backbone.bn4.bias", "backbone.bn4.running_mean", "backbone.bn4.running_var", "backbone.conv5.weight", "backbone.bn5.weight", "backbone.bn5.bias", "backbone.bn5.running_mean", "backbone.bn5.running_var", "backbone.conv6.weight", "backbone.bn6.weight", "backbone.bn6.bias", "backbone.bn6.running_mean", "backbone.bn6.running_var", "backbone.conv7.weight", "backbone.bn7.weight", "backbone.bn7.bias", "backbone.bn7.running_mean", "backbone.bn7.running_var", "classifier.weight", "classifier.bias". 
	Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.1.num_batches_tracked", "features.4.weight", "features.4.bias", "features.5.weight", "features.5.bias", "features.5.running_mean", "features.5.running_var", "features.5.num_batches_tracked", "features.8.weight", "features.8.bias", "features.9.weight", "features.9.bias", "features.9.running_mean", "features.9.running_var", "features.9.num_batches_tracked", "features.11.weight", "features.11.bias", "features.12.weight", "features.12.bias", "features.12.running_mean", "features.12.running_var", "features.12.num_batches_tracked", "features.15.weight", "features.15.bias", "features.16.weight", "features.16.bias", "features.16.running_mean", "features.16.running_var", "features.16.num_batches_tracked", "features.18.weight", "features.18.bias", "features.19.weight", "features.19.bias", "features.19.running_mean", "features.19.running_var", "features.19.num_batches_tracked", "features.22.weight", "features.22.bias", "features.23.weight", "features.23.bias", "features.23.running_mean", "features.23.running_var", "features.23.num_batches_tracked", "features.25.weight", "features.25.bias", "features.26.weight", "features.26.bias", "features.26.running_mean", "features.26.running_var", "features.26.num_batches_tracked", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias". 

In [None]:
VGG11_BN_Weights.__dict__

In [None]:
torch_vgg

In [None]:
for module in torch_vgg.modules():
    print(module)

In [10]:
torch_vgg

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke

In [16]:
for module in torch_vgg.named_modules():
    print(module)

('', VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool