A small experiment to play with and compare a few different pruning methods.

In [1]:
from torchvision.models import vit_b_32, ViT_B_32_Weights
import torch
from torch import nn
from functools import partial
import torch.utils.benchmark as benchmark
import torch.nn.utils.prune as prune


device = torch.device("cuda")
base_model_init = partial(vit_b_32, weights=ViT_B_32_Weights.IMAGENET1K_V1)

control_model = base_model_init()
control_model.to(device)

models = {"control": control_model}


PyTorch's built-in pruning support.
Note that even its "structured" pruning only zeros parameters: it doesn't remove the zeroed dimensions, so every tensor keeps its original shape. It therefore won't have any effect on the speed or size of the model; it just sparsifies it. To truly perform structured pruning and get those benefits, we'd have to remove the zeroed elements manually.
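To make that concrete, here is a minimal sketch on a toy two-layer MLP (a hypothetical stand-in, not the ViT used in this notebook). `prune.ln_structured` zeros whole output rows of the first layer but leaves its shape unchanged; physically removing those rows also means dropping the matching input columns of the next layer. The first layer is built without a bias so that a zeroed row outputs exactly 0, which keeps the sliced network numerically equivalent; with a bias, the removed rows' bias terms would need extra handling.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Hypothetical toy model. fc1 has no bias, so a zeroed weight row outputs exactly 0.
fc1 = nn.Linear(8, 16, bias=False)
fc2 = nn.Linear(16, 4)
dense = nn.Sequential(fc1, nn.ReLU(), fc2)

# Zero the 25% of fc1's output rows with the smallest L1 norm, then make it permanent.
prune.ln_structured(fc1, name="weight", amount=0.25, n=1, dim=0)
prune.remove(fc1, "weight")
print(fc1.weight.shape)  # torch.Size([16, 8]): rows are zeroed, not removed

# Boolean mask of the rows that survived pruning.
keep = fc1.weight.abs().sum(dim=1) != 0
n_keep = int(keep.sum())

# Rebuild smaller layers: fc1 loses the pruned rows, fc2 loses the matching input columns.
small_fc1 = nn.Linear(8, n_keep, bias=False)
small_fc2 = nn.Linear(n_keep, 4)
with torch.no_grad():
    small_fc1.weight.copy_(fc1.weight[keep])
    small_fc2.weight.copy_(fc2.weight[:, keep])
    small_fc2.bias.copy_(fc2.bias)
small = nn.Sequential(small_fc1, nn.ReLU(), small_fc2)

x = torch.randn(5, 8)
print(small_fc1.weight.shape)  # torch.Size([12, 8])
print(torch.allclose(dense(x), small(x), atol=1e-6))  # True: same outputs, smaller layers
```

In a real network the same bookkeeping has to be repeated for every layer that consumes the pruned dimension, which is why this slicing step is not automatic.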

In [2]:
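### Built-in PyTorch pruning

# Local unstructured: randomly zero 30% of the individual weights in each Conv2d.
# Note: in ViT-B/32 the only Conv2d is the patch-embedding projection (`conv_proj`);
# the encoder is built from Linear/MultiheadAttention layers, so only that one layer is pruned here.
torch_local_unstruct_model = base_model_init()
torch_local_unstruct_model.to(device)

for _, module in torch_local_unstruct_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.random_unstructured(module, name="weight", amount=0.3)
        # Make the pruning permanent: drops weight_orig/weight_mask and leaves zeros in `weight`
        prune.remove(module, 'weight')

models["torch_local_unstruct"] = torch_local_unstruct_model

# Local structured: randomly zero 30% of each Conv2d's output channels (dim=0).
# The weight tensor keeps its shape; the chosen channels are just filled with zeros.
torch_local_struct_model = base_model_init()
torch_local_struct_model.to(device)

for _, module in torch_local_struct_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.random_structured(module, name="weight", amount=0.3, dim=0)
        prune.remove(module, 'weight')

models["torch_local_struct"] = torch_local_struct_model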



In [32]:
def compare_modelsize(models):
    # Memory taken by every parameter and buffer, counted as dense tensors
    # (zeroed entries still take up space).
    for model_name, model in models.items():
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        size_all_mb = (param_size + buffer_size) / 1024**2
        print(f"\"{model_name}\" size: {size_all_mb:.3f}MB")
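Since dense tensors are stored at full size whether or not their entries are zero, `compare_modelsize` reports the same number for every model here. A small companion helper (`param_sparsity` is a hypothetical name, not part of the original notebook) that reports the fraction of parameter entries that are exactly zero shows what the pruning actually changed. It assumes `prune.remove` has already been called; before that, `model.parameters()` yields the unmasked `weight_orig` tensors.

```python
from torch import nn
import torch.nn.utils.prune as prune

def param_sparsity(model: nn.Module) -> float:
    """Hypothetical helper: fraction of all parameter entries that are exactly zero."""
    total = 0
    zeros = 0
    for param in model.parameters():
        total += param.numel()
        zeros += int((param == 0).sum())
    return zeros / total

# Usage on a toy layer: prune half of a 10x10 weight matrix at random.
layer = nn.Linear(10, 10, bias=False)
prune.random_unstructured(layer, name="weight", amount=0.5)
prune.remove(layer, "weight")
print(param_sparsity(layer))  # 0.5
```

In this notebook it could be printed next to the size inside the same loop over `models`.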

In [33]:
def compare_benchmarks(models, batches=10, batch_size=16, channels=3, input_size=(224, 224)):
    results = []
    torch.manual_seed(0)

    for _ in range(batches):
        x = torch.randn(batch_size, channels, *input_size).to(device)

        label = "Model comparison"
        sub_label = "Inference"

        for modelname, model in models.items():
            results.append(benchmark.Timer(
                stmt='model(x)',
                globals={'x': x, 'model': model},
                num_threads=1,
                label=label,
                sub_label=sub_label,
                description=modelname
            ).blocked_autorange(min_run_time=1))

    compare = benchmark.Compare(results)
    compare.print()
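The timer above runs each model with autograd enabled and without calling `model.eval()`. A minimal variant, assuming only inference speed is of interest, puts the model in eval mode and wraps the call in `torch.inference_mode()` so no autograd graph is recorded. `time_inference` is a hypothetical helper name, and the usage below runs a tiny stand-in model on CPU rather than the ViT.

```python
import torch
from torch import nn
import torch.utils.benchmark as benchmark

def time_inference(model: nn.Module, x: torch.Tensor, description: str,
                   min_run_time: float = 1.0) -> benchmark.Measurement:
    """Hypothetical helper: time a forward pass with autograd disabled."""
    model.eval()  # note: switches the model passed in to eval mode (e.g. dropout off)
    timer = benchmark.Timer(
        stmt="with torch.inference_mode():\n    model(x)",
        globals={"torch": torch, "model": model, "x": x},
        label="Model comparison",
        sub_label="Inference (inference_mode)",
        description=description,
    )
    return timer.blocked_autorange(min_run_time=min_run_time)

# Usage with a tiny CPU stand-in model.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
m = time_inference(model, torch.randn(16, 64), "toy", min_run_time=0.2)
print(f"median: {m.median * 1e6:.1f} us")
```

The returned measurements can be collected into a list and passed to `benchmark.Compare` exactly as in `compare_benchmarks`.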

In [34]:
compare_benchmarks(models)
compare_modelsize(models)

[---------------------------------------- Model comparison ---------------------------------------]
                 |  control  |  torch_rand_unstruct  |  torch_local_unstruct  |  torch_local_struct
1 threads: ----------------------------------------------------------------------------------------
      Inference  |    22.0   |          23.8         |          22.1          |         22.1       

Times are in milliseconds (ms).

"control" size: 336.549MB
"torch_rand_unstruct" size: 336.549MB
"torch_local_unstruct" size: 336.549MB
"torch_local_struct" size: 336.549MB


As expected, this doesn't improve the model's speed or size: the pruned parameters are zeroed but never sliced out, so every tensor keeps its original shape and the dense kernels do the same amount of work. It was still a useful exercise for getting familiar with PyTorch's built-in pruning methods and some of its benchmarking tools.
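Storing the zeroed weights in a sparse format would not rescue the size at this level of sparsity either. A rough sketch, assuming PyTorch's COO layout (one int64 index per dimension plus one value per non-zero): for a 2-D float32 matrix each kept entry costs 2·8 + 4 = 20 bytes instead of 4, so COO only comes out smaller once more than about 80% of the entries are zero. The 768×768 shape and the helper names `dense_bytes`/`coo_bytes` are illustrative choices.

```python
from torch import nn
import torch.nn.utils.prune as prune

def dense_bytes(t):
    return t.numel() * t.element_size()

def coo_bytes(t):
    # COO stores an index tensor (ndim x nnz, int64) plus a values tensor (nnz entries).
    s = t.to_sparse().coalesce()
    return (s.indices().numel() * s.indices().element_size()
            + s.values().numel() * s.values().element_size())

for amount in (0.3, 0.9):
    layer = nn.Linear(768, 768, bias=False)
    prune.l1_unstructured(layer, name="weight", amount=amount)
    prune.remove(layer, "weight")
    w = layer.weight.detach()
    print(f"sparsity {amount:.0%}: dense {dense_bytes(w) / 1024**2:.2f}MB, "
          f"COO {coo_bytes(w) / 1024**2:.2f}MB")
```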