# Pruning Example
This notebook demonstrates how to use the pruning methods from this tool to compress a model. 

The example uses the MNIST dataset and a simple CNN model. The model is trained and then pruned using the methods in this tool. The pruned model is then evaluated on the test set to see how well it performs.

## Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#Define pytorch model
import os
import sys
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch.optim as optim
import transformers
import importlib
import inspect
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Add thesis package to path
sys.path.append("../")

import src.general as general

In [3]:
# Set device
no_cuda = False
use_cuda = not no_cuda and torch.cuda.is_available()
print(f"Using cuda: {use_cuda}")

device = torch.device("cuda" if use_cuda else "cpu")

Using cuda: False


### Load model

In [4]:
model_state = "../models/mnist.pt"
model_class = "models.mnist"

# Import the module classes
module = importlib.import_module(model_class)
classes = general.get_module_classes(module)
for cls in classes:
    globals()[cls.__name__] = cls

# Get device
device = general.get_device()

# Load the model
model = torch.load(model_state, map_location=torch.device(device))

Using cuda: False


Prepare data loaders for training and testing.

In [5]:
# Load MNIST dataset
batch_size = 64
test_batch_size = 64
use_cuda = False


kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
mnist_transform = transform=transforms.Compose([
                           transforms.ToTensor()
                       ])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform,),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform,),
    batch_size=test_batch_size, shuffle=True, **kwargs)

Define train and test methods.

In [6]:
# Hyperparameters
epochs = 3
lr = 0.01
momentum = 0.5
log_interval = 100

In [9]:
from models.mnist import MnistModel
from src.general import train, test

model = MnistModel().to(device)

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = F.nll_loss

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, criterion, optimizer)
    test(model, device, test_loader, criterion)

[autoreload of src.general failed: Traceback (most recent call last):
  File "/Users/abel/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/Users/abel/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/Users/abel/miniconda3/envs/torch-gpu/lib/python3.8/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/Users/abel/miniconda3/envs/torch-gpu/lib/python3.8/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 604, in _exec
  File "<frozen importlib._bootstrap_external>", line 839, in exec_module
  File "<frozen importlib._bootstrap_external>", line 976, in get_code
  File "<frozen importlib._bootstrap_external>", line 906, in source_to_code
  File "<frozen importlib._bootstrap>", line 219, 

Average loss = 0.0914
Elapsed time = 20909.89 milliseconds (22.29 per batch, 0.70 per data point)


Test: 100%|██████████| 157/157 [00:01<00:00, 135.83it/s]


Average loss = 0.1821
Elapsed time = 1156.89 milliseconds (7.37 per batch, 0.46 per data point)


Train: 100%|██████████| 938/938 [00:20<00:00, 45.86it/s]


Average loss = 0.1091
Elapsed time = 20454.45 milliseconds (21.81 per batch, 0.68 per data point)


Test: 100%|██████████| 157/157 [00:01<00:00, 135.84it/s]


Average loss = 0.1095
Elapsed time = 1156.83 milliseconds (7.37 per batch, 0.46 per data point)


Train: 100%|██████████| 938/938 [00:20<00:00, 45.08it/s]


Average loss = 0.0772
Elapsed time = 20807.45 milliseconds (22.18 per batch, 0.69 per data point)


Test: 100%|██████████| 157/157 [00:01<00:00, 145.91it/s]

Average loss = 0.0785
Elapsed time = 1077.15 milliseconds (6.86 per batch, 0.43 per data point)





In [None]:
save_model = True
save_model_path = "../models/mnist.pt"

if save_model:
    torch.save(model, save_model_path)

## Pruning

In [None]:
# Test model performance before pruning
print("Before pruning")
test(model, device, test_loader)

# Define the parameters to prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.95,
)

# Test model performance after pruning
print("After pruning")
test(model, device, test_loader)

# Print number of parameters in model
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')


# Print number of parameters pruned
total_pruned_params = sum(p.numel() for p in model.parameters() if hasattr(p, 'mask'))
print(f'{total_pruned_params:,} parameters pruned ({100 * total_pruned_params / total_params:.2f}% pruned)')

# Print number of parameters remaining
total_unpruned_params = sum(p.numel() for p in model.parameters() if not hasattr(p, 'mask'))
print(f'{total_unpruned_params:,} parameters unpruned ({100 * total_unpruned_params / total_params:.2f}% unpruned)')


Before pruning


TypeError: test() missing 2 required positional arguments: 'criterion' and 'epoch'