Small experiment to play with and compare some different pruning methods

In [1]:
from torchvision.models import resnet18, ResNet18_Weights
import torch
from torch import nn
from functools import partial
import torch.utils.benchmark as benchmark
import torch.nn.utils.prune as prune


device = torch.device("cuda")
base_model_init = partial(resnet18, weights=ResNet18_Weights.IMAGENET1K_V1)

control_model = base_model_init()
control_model.to(device)

models = {"control": control_model}


PyTorch's built-in pruning support.
Note that even its "structured" pruning only zeros parameters; it doesn't remove the zeroed dimensions, so it has no effect on the model's speed or size and just sparsifies it. To actually perform structural pruning and get those benefits, we'd have to remove the zeroed elements manually.
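A quick way to see this on a single layer, as a sketch with an arbitrary small Conv2d: after random_structured with dim=0 and prune.remove, the weight keeps its original shape, and half of its output-channel filters are simply all zeros. (The layer's bias isn't touched by this call.)

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)
shape_before = tuple(conv.weight.shape)

# "Structured" pruning of 50% of the output channels (dim=0), then make it permanent
prune.random_structured(conv, name="weight", amount=0.5, dim=0)
prune.remove(conv, "weight")

shape_after = tuple(conv.weight.shape)
# Count output channels whose entire filter is now zero
zeroed = int((conv.weight.detach().flatten(1).abs().sum(dim=1) == 0).sum())
print(shape_before, shape_after, zeroed)  # (8, 3, 3, 3) (8, 3, 3, 3) 4
```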

In [None]:
### Built-in PyTorch pruning

# Local unstructured: zero a random 30% of the weights in every Conv2d
torch_local_unstruct_model = base_model_init()
torch_local_unstruct_model.to(device)

for _, module in torch_local_unstruct_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.random_unstructured(module, name="weight", amount=0.3)
        # Make the pruning permanent: fold the mask into "weight" and drop weight_orig/weight_mask
        prune.remove(module, 'weight')

models["torch_local_unstruct"] = torch_local_unstruct_model

# Local structured: zero a random 30% of the output channels (dim=0) in every Conv2d
torch_local_struct_model = base_model_init()
torch_local_struct_model.to(device)

for _, module in torch_local_struct_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.random_structured(module, name="weight", amount=0.3, dim=0)
        prune.remove(module, 'weight')

models["torch_local_struct"] = torch_local_struct_model
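Because the pruned models stay dense, a size comparison alone can't confirm that pruning happened. A small helper (conv_weight_sparsity is a hypothetical name, not part of torch.nn.utils.prune) that reports the fraction of exactly-zero Conv2d weights, shown here on a toy model; applied to torch_local_unstruct it should come out close to the 0.3 used above.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def conv_weight_sparsity(model: nn.Module) -> float:
    """Fraction of all Conv2d weight entries in `model` that are exactly zero."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            zeros += int((module.weight == 0).sum())
            total += module.weight.numel()
    return zeros / total if total else 0.0


# Toy model: prune 30% of each conv's weights, the same way as the cell above
torch.manual_seed(0)
toy = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
n_params_before = sum(p.numel() for p in toy.parameters())
for m in toy.modules():
    if isinstance(m, nn.Conv2d):
        prune.random_unstructured(m, name="weight", amount=0.3)
        prune.remove(m, "weight")

n_params_after = sum(p.numel() for p in toy.parameters())
print(round(conv_weight_sparsity(toy), 3), n_params_before == n_params_after)  # 0.3 True
```

The parameter count (and therefore the dense size in bytes) is identical before and after; only the number of zeros changes.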


In [3]:
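def compare_modelsize(models):
    """Print each model's in-memory size: all parameters plus buffers (e.g. BatchNorm running stats).

    Zeroed weights still count, since they are stored in dense tensors.
    """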
    for model_name, model in models.items():
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        size_all_mb = (param_size + buffer_size) / 1024**2
        print(f"\"{model_name}\" size: {size_all_mb:.3f}MB")

In [4]:
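def compare_benchmarks(models, batches=10, batch_size=16, channels=3, input_size=(224, 224)):
    """Time a forward pass of every model on `batches` random input batches and print a comparison table."""
    results = []
    torch.manual_seed(0)
    label = "Model comparison"
    sub_label = "Inference"

    for _ in range(batches):
        x = torch.randn(batch_size, channels, *input_size).to(device)

        for modelname, model in models.items():
            # Timer handles warmup and CUDA synchronization, repeating model(x) for at least min_run_time seconds
            results.append(benchmark.Timer(
                stmt='model(x)',
                globals={'x': x, 'model': model},
                num_threads=1,
                label=label,
                sub_label=sub_label,
                description=modelname
            ).blocked_autorange(min_run_time=1))

    # Replicate measurements (same label/sub_label/description) are merged into one table cell
    compare = benchmark.Compare(results)
    compare.print()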
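One thing worth noting about this benchmark: the torchvision models are never switched to eval mode and the forward pass runs with autograd enabled, so BatchNorm uses batch statistics and each call records a graph. Below is a sketch of an inference-style timing helper (time_inference is a hypothetical name) that calls model.eval() and wraps the forward pass in torch.inference_mode(). It hasn't been run against the models in this notebook, and absolute times may differ from the tables.

```python
import torch
from torch import nn
import torch.utils.benchmark as benchmark


def time_inference(model: nn.Module, x: torch.Tensor, description: str,
                   min_run_time: float = 1.0):
    """Time model(x) in eval mode with autograd disabled. Note: leaves the model in eval mode."""
    model.eval()  # BatchNorm uses its running stats; Dropout becomes a no-op

    def run():
        with torch.inference_mode():  # skip recording an autograd graph
            return model(x)

    return benchmark.Timer(
        stmt="run()",
        globals={"run": run},
        num_threads=1,
        label="Model comparison",
        sub_label="Inference (eval + inference_mode)",
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


# Toy usage on CPU; with the notebook's models you would pass e.g. models["control"] and a CUDA input
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
measurement = time_inference(toy, torch.randn(2, 3, 32, 32), "toy", min_run_time=0.1)
benchmark.Compare([measurement]).print()
```

The returned measurements can be collected into a list and passed to benchmark.Compare in the same way as in compare_benchmarks.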

In [5]:
compare_benchmarks(models)
compare_modelsize(models)

[---------------------------- Model comparison ---------------------------]
                 |  control  |  torch_local_unstruct  |  torch_local_struct
1 threads: ----------------------------------------------------------------
      Inference  |    8.6    |          8.6           |         8.6        

Times are in milliseconds (ms).

"control" size: 44.629MB
"torch_local_unstruct" size: 44.629MB
"torch_local_struct" size: 44.629MB


As expected, this doesn't improve the model's speed or size at all: the pruned parameters are only zeroed, not sliced out, so every layer still stores and computes with the same dense tensors. It was still a useful exercise for getting familiar with PyTorch's built-in pruning methods and some of its benchmarking tools.
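The unchanged 44.629MB can be reproduced by hand. A sketch, assuming the standard torchvision ResNet-18 layout and float32 weights: count the parameters, add the BatchNorm buffers, and convert bytes to MiB the same way compare_modelsize does. Zeroing a weight doesn't change nelement(), so none of the pruning above can move this number; only shrinking the layer widths can.

```python
# Stage layout of torchvision's ResNet-18: (width, number of BasicBlocks, stride of first block)
STAGES = [(64, 2, 1), (128, 2, 2), (256, 2, 2), (512, 2, 2)]


def conv_params(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k * k  # ResNet convs have bias=False


def resnet18_counts(num_classes: int = 1000):
    """Return (parameter count, BatchNorm buffer bytes) for a ResNet-18."""
    params = conv_params(3, 64, 7) + 2 * 64  # stem conv1 + bn1 (weight and bias)
    bn_channels, bn_layers = 64, 1
    c_in = 64
    for width, blocks, stride in STAGES:
        for b in range(blocks):
            params += conv_params(c_in, width, 3) + conv_params(width, width, 3)
            params += 2 * (2 * width)  # bn1 and bn2, weight and bias each
            bn_channels += 2 * width
            bn_layers += 2
            if b == 0 and (stride != 1 or c_in != width):
                params += conv_params(c_in, width, 1) + 2 * width  # downsample conv + bn
                bn_channels += width
                bn_layers += 1
            c_in = width
    params += 512 * num_classes + num_classes  # fc weight + bias
    # Buffers: float32 running_mean and running_var per channel, int64 num_batches_tracked per BN layer
    buffer_bytes = 2 * bn_channels * 4 + bn_layers * 8
    return params, buffer_bytes


params, buffer_bytes = resnet18_counts()
size_mb = (params * 4 + buffer_bytes) / 1024**2  # float32 params, same formula as compare_modelsize
print(params, f"{size_mb:.3f}MB")  # 11689512 44.629MB
```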

Let's try a third-party library. There isn't much in the way of fully automatic solutions; I guess the most common approach is to use something like PyTorch's sparsification and then manually add a step during training (or afterwards) that rebuilds the model with the unnecessary nodes removed and copies the remaining weights across?
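A minimal sketch of that manual step for the simplest case, assuming a hypothetical conv -> ReLU -> conv chain where the first conv has no bias and there's no BatchNorm in between. Once an output channel's filter is all zeros, its activation is zero, so we can drop that filter and the matching input channel of the next conv without changing the output. In a real ResNet, the BatchNorm after each conv (which maps a zero input to a generally non-zero value) and the residual connections that tie several layers' channels together would also need handling; that is the kind of bookkeeping a library like Torch-Pruning automates.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical chain: conv_a -> ReLU -> conv_b (conv_a has no bias, no BatchNorm in between)
conv_a = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
conv_b = nn.Conv2d(8, 4, kernel_size=3, padding=1)

# Zero half of conv_a's output channels with PyTorch's pruning, then make it permanent
prune.random_structured(conv_a, name="weight", amount=0.5, dim=0)
prune.remove(conv_a, "weight")


def slice_zeroed_channels(conv_a: nn.Conv2d, conv_b: nn.Conv2d):
    """Return smaller copies of conv_a/conv_b with conv_a's all-zero output channels removed."""
    # A bias would make a "zeroed" channel output a constant, so it couldn't be dropped
    assert conv_a.bias is None
    keep = conv_a.weight.detach().flatten(1).abs().sum(dim=1) != 0
    idx = keep.nonzero().flatten()  # indices of the surviving output channels

    new_a = nn.Conv2d(conv_a.in_channels, len(idx), conv_a.kernel_size,
                      stride=conv_a.stride, padding=conv_a.padding, bias=False)
    new_b = nn.Conv2d(len(idx), conv_b.out_channels, conv_b.kernel_size,
                      stride=conv_b.stride, padding=conv_b.padding,
                      bias=conv_b.bias is not None)
    with torch.no_grad():
        new_a.weight.copy_(conv_a.weight[idx])     # keep only the surviving filters
        new_b.weight.copy_(conv_b.weight[:, idx])  # drop the matching input channels
        if conv_b.bias is not None:
            new_b.bias.copy_(conv_b.bias)
    return new_a, new_b


small_a, small_b = slice_zeroed_channels(conv_a, conv_b)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    dense_out = conv_b(torch.relu(conv_a(x)))
    small_out = small_b(torch.relu(small_a(x)))

print(small_a.weight.shape, torch.allclose(dense_out, small_out, atol=1e-5))  # torch.Size([4, 3, 3, 3]) True
```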

Torch-Pruning (imported as torch_pruning) seems quite full-featured and is still actively maintained.

In [6]:
import torch_pruning as tp

high_level_tp_model = base_model_init()
high_level_tp_model.to(device)

example_inputs = torch.randn(1, 3, 224, 224).to(device)

imp = tp.importance.GroupMagnitudeImportance(p=2) 
ignored_layers = []
for m in high_level_tp_model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)

pruner = tp.pruner.BasePruner(
    high_level_tp_model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove ~50% of channels; with global_pruning=True the ratio varies per layer, so ResNet18 widths {64, 128, 256, 512} won't simply halve to {32, 64, 128, 256}
    # pruning_ratio_dict = {high_level_tp_model.conv1: 0.2, high_level_tp_model.layer2: 0.8}, # customized pruning ratios for layers or blocks
    ignored_layers=ignored_layers,
    round_to=8, # It's recommended to round dims/channels to 4x or 8x for acceleration. Please see: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
    isomorphic=True, # enable isomorphic pruning to improve global ranking
    global_pruning=True, # global pruning
)

base_macs, base_nparams = tp.utils.count_ops_and_params(high_level_tp_model, example_inputs)
tp.utils.print_tool.before_pruning(high_level_tp_model) # or print(model)
pruner.step()
tp.utils.print_tool.after_pruning(high_level_tp_model) # or print(model), this util will show the difference before and after pruning
macs, nparams = tp.utils.count_ops_and_params(high_level_tp_model, example_inputs)
print(f"MACs: {base_macs/1e9:.3f} G -> {macs/1e9:.3f} G, #Params: {base_nparams/1e6:.3f} M -> {nparams/1e6:.3f} M")

models["high_level_tp"] = high_level_tp_model


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) => (conv1): Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1
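The printout shows that global pruning doesn't halve every layer. A rough sketch of why, using a plain global L2-norm ranking of filters across layers (this is not Torch-Pruning's grouped/isomorphic algorithm, and the round-up rule in the hypothetical global_keep_counts helper is an assumption): the lowest-scoring filters are removed wherever they are, so a layer with weaker filters absorbs most of the pruning.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)


def filter_l2_norms(conv: nn.Conv2d) -> torch.Tensor:
    # One score per output channel: the L2 norm of that channel's filter
    return conv.weight.detach().flatten(1).norm(p=2, dim=1)


def global_keep_counts(convs, pruning_ratio=0.5, round_to=8):
    """Hypothetical helper: rank every filter of every layer together and
    return how many output channels each layer would keep."""
    scores = [filter_l2_norms(c) for c in convs]
    all_scores = torch.cat(scores)
    n_prune = int(pruning_ratio * all_scores.numel())
    if n_prune == 0:
        return [s.numel() for s in scores]
    # The n_prune lowest-scoring filters (across all layers) are removed
    threshold = all_scores.sort().values[n_prune - 1]
    counts = []
    for s in scores:
        keep = int((s > threshold).sum())
        # Round the surviving width up to a multiple of round_to (and keep at least round_to)
        keep = max(round_to, math.ceil(keep / round_to) * round_to)
        counts.append(min(keep, s.numel()))
    return counts


layer_a = nn.Conv2d(16, 64, kernel_size=3)
layer_b = nn.Conv2d(16, 64, kernel_size=3)
with torch.no_grad():
    layer_b.weight.mul_(0.1)  # give layer_b systematically smaller filters

print(global_keep_counts([layer_a, layer_b]))  # [64, 8]
```

With a per-layer (local) 50% ratio, both layers would keep 32 channels. Ranking globally instead concentrates the pruning in the weaker layer, which is consistent with the uneven widths above, e.g. layer1's first conv going from 64 to 24 rather than 32.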

In [7]:
compare_benchmarks(models)
compare_modelsize(models)

[------------------------------------- Model comparison ------------------------------------]
                 |  control  |  torch_local_unstruct  |  torch_local_struct  |  high_level_tp
1 threads: ----------------------------------------------------------------------------------
      Inference  |    8.6    |          8.6           |         8.6          |       3.4     

Times are in milliseconds (ms).

"control" size: 44.629MB
"torch_local_unstruct" size: 44.629MB
"torch_local_struct" size: 44.629MB
"high_level_tp" size: 11.544MB


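This greatly improves both speed and size: inference drops from 8.6 ms to 3.4 ms and the model shrinks from 44.629MB to 11.544MB (accuracy isn't evaluated here, and the pruned model hasn't been fine-tuned). Worth noting: when I tried this library with torchvision's prebuilt ViT, it broke, possibly because of how the weights in the transformer/attention blocks are structured. Adding those blocks to the ignore list may make it work, although the gains will be much smaller if the attention blocks can't be pruned.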
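To get a feel for how much pruning headroom is lost by ignoring attention, here's a sketch that collects nn.MultiheadAttention modules into a list that could be passed as ignored_layers, and measures what fraction of the parameters they hold. It assumes the model implements attention with nn.MultiheadAttention (as torchvision's VisionTransformer does) and uses a small nn.TransformerEncoder with hypothetical sizes as a stand-in for a real ViT. Whether ignoring these modules is enough to make Torch-Pruning handle the ViT is untested here.

```python
import torch
from torch import nn


def attention_modules(model: nn.Module) -> list:
    """Collect every nn.MultiheadAttention submodule, e.g. to use as ignored_layers."""
    return [m for m in model.modules() if isinstance(m, nn.MultiheadAttention)]


def param_fraction(model: nn.Module, modules) -> float:
    """Fraction of the model's parameters that live inside `modules`."""
    inside = sum(p.numel() for m in modules for p in m.parameters())
    total = sum(p.numel() for p in model.parameters())
    return inside / total


# Small stand-in for a ViT encoder (hypothetical sizes), with the usual 4x MLP expansion
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

ignored_layers = attention_modules(encoder)
print(len(ignored_layers), round(param_fraction(encoder, ignored_layers), 3))  # 2 0.333
```

With a 4x MLP expansion, attention holds roughly 4d² of each block's roughly 12d² parameters, so about two-thirds of each encoder block (the MLP) would remain prunable in this layout, not counting the patch embedding and classification head.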