# Pruning Example
This notebook demonstrates how to use the pruning methods from this tool to compress a model. 

The example uses the MNIST dataset and a simple CNN model. The model is trained and then pruned using the methods in this tool. The pruned model is then evaluated on the test set to see how well it performs.

## Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [17]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers
import importlib
import inspect
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Add thesis package to path
sys.path.append("../")

import src.general as general
import src.compression.pruning as pruning
import src.metrics as metrics

In [16]:
model_state = "../models/mnist.pt"
model_class = "models.mnist"

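# Expose the model classes at top level, presumably so that torch.load can unpickle
# a model whose classes lived in __main__ when it was saved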
module = importlib.import_module(model_class)
classes = general.get_module_classes(module)
for cls in classes:
    globals()[cls.__name__] = cls

# Get device
device = general.get_device()

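# Load the full pickled model (PyTorch >= 2.6 defaults to weights_only=True;
# pass weights_only=False there to load a whole model object)
model = torch.load(model_state, map_location=torch.device(device))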

Using cuda: False
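
`torch.load` restores a whole pickled model here, so the model classes have to be resolvable when it is unpickled; copying them into `globals()` presumably covers the case where the model was saved from a script or notebook that defined them at top level. `general.get_module_classes` is this project's own helper and its source isn't shown, so the following is only a hypothetical standard-library equivalent. It assumes the helper returns the classes defined in the module itself rather than ones the module merely imports.

```python
import inspect
import types


def get_module_classes(module: types.ModuleType) -> list:
    """Hypothetical stand-in: return the classes defined in `module` itself."""
    return [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        # Skip classes the module only imports (their __module__ points elsewhere)
        if obj.__module__ == module.__name__
    ]


# Usage with a throwaway module standing in for models.mnist
demo = types.ModuleType("demo_models")
exec("from collections import OrderedDict\nclass Net: pass\nclass Helper: pass", demo.__dict__)
print([cls.__name__ for cls in get_module_classes(demo)])
```

`OrderedDict` is visible in the module namespace but is filtered out, because it was defined in `collections`.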


In [5]:
# Load MNIST dataset
batch_size = 64
test_batch_size = 64

kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
# ToTensor converts 28x28 uint8 images to float tensors scaled to [0, 1]
mnist_transform = transforms.Compose([
    transforms.ToTensor()
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True, **kwargs)

# The evaluation set does not need shuffling
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform),
    batch_size=test_batch_size, shuffle=False, **kwargs)
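
With `DataLoader`'s default `drop_last=False`, each epoch yields the dataset size divided by the batch size, rounded up, with a smaller final batch. For the standard 60,000/10,000 MNIST split and a batch size of 64, this arithmetic gives exactly the 938 training and 157 test iterations that appear in the progress bars of the training output:

```python
import math

# Standard MNIST split sizes and the batch sizes used above
train_size, test_size = 60_000, 10_000
batch_size = test_batch_size = 64

# drop_last=False (the DataLoader default) keeps the final, partial batch
train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / test_batch_size)
last_train_batch = train_size - (train_batches - 1) * batch_size

print(train_batches, test_batches, last_train_batch)  # 938 157 32
```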

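Set the training hyperparameters, then train the model with the `train` and `test` helpers from `src.general`, evaluating on the test set after every epoch.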

In [6]:
# Hyperparameters
epochs = 3
lr = 0.01
momentum = 0.5
log_interval = 100
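
`optim.SGD` with `momentum=0.5` keeps a velocity buffer per parameter. Following the update rule in PyTorch's `SGD` documentation with the defaults used here (no dampening, no Nesterov, no weight decay), the scalar sketch below runs both plain SGD and SGD with momentum on a toy objective f(x) = x², which is an illustration and not part of the notebook, using the same `lr`:

```python
def sgd_momentum(grad_fn, x0: float, lr: float, momentum: float, steps: int) -> float:
    """Scalar SGD with heavy-ball momentum: v <- momentum * v + g, x <- x - lr * v."""
    x, velocity = x0, 0.0  # a zero buffer makes the first step equal to PyTorch's (v = g)
    for _ in range(steps):
        g = grad_fn(x)
        velocity = momentum * velocity + g  # decaying sum of past gradients
        x -= lr * velocity                  # step along the accumulated direction
    return x


def grad(x: float) -> float:
    return 2 * x  # gradient of the toy objective f(x) = x**2


x_momentum = sgd_momentum(grad, x0=5.0, lr=0.01, momentum=0.5, steps=200)
x_plain = sgd_momentum(grad, x0=5.0, lr=0.01, momentum=0.0, steps=200)
print(f"with momentum: {x_momentum:.2e}, without: {x_plain:.2e}")
```

With `momentum=0.0` the loop reduces to plain gradient descent, so the comparison isolates the effect of the velocity buffer.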

In [10]:
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# nll_loss expects log-probabilities, so this assumes the model's forward pass ends in log_softmax
criterion = F.nll_loss

for epoch in range(1, epochs + 1):
    general.train(model, device, train_loader, criterion, optimizer, metric=metrics.accuracy)
    general.test(model, device, test_loader, criterion, metric=metrics.accuracy)

Train: 100%|██████████| 938/938 [00:20<00:00, 46.03it/s]


Average loss = 0.3792
Elapsed time = 20381.11 milliseconds (21.73 per batch, 0.68 per data point)


Test: 100%|██████████| 157/157 [00:01<00:00, 148.38it/s]


Average loss = 0.1961
Accuracy = 0.9396
Elapsed time = 1059.20 milliseconds (6.75 per batch, 0.42 per data point)


Train: 100%|██████████| 938/938 [00:20<00:00, 46.46it/s]


Average loss = 0.0631
Elapsed time = 20191.55 milliseconds (21.53 per batch, 0.67 per data point)


Test: 100%|██████████| 157/157 [00:01<00:00, 146.45it/s]


Average loss = 0.0986
Accuracy = 0.9702
Elapsed time = 1073.09 milliseconds (6.83 per batch, 0.43 per data point)


Train: 100%|██████████| 938/938 [00:20<00:00, 46.41it/s]


Average loss = 0.0688
Elapsed time = 20213.10 milliseconds (21.55 per batch, 0.67 per data point)


Test: 100%|██████████| 157/157 [00:01<00:00, 147.40it/s]

Average loss = 0.0678
Accuracy = 0.9792
Elapsed time = 1066.15 milliseconds (6.79 per batch, 0.42 per data point)
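
The loop above relies on project helpers whose code isn't shown: `general.train` and `general.test` report an average loss, and `metrics.accuracy` reports accuracy. Because `F.nll_loss` expects log-probabilities, the setup assumes the `models.mnist` network ends in a `log_softmax`. The pure-Python sketch below shows what the two logged quantities measure on a toy batch; `accuracy` here is a hypothetical argmax-based stand-in, not the project's `metrics.accuracy`.

```python
import math


def log_softmax(logits: list) -> list:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def nll_loss(log_probs: list, targets: list) -> float:
    """Mean negative log-likelihood of the target class (F.nll_loss's default 'mean' reduction)."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def accuracy(log_probs: list, targets: list) -> float:
    """Hypothetical metric: fraction of rows whose highest-scoring class equals the target."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Two samples, three classes; the second sample is misclassified (predicts class 2, target 1)
batch = [log_softmax([2.0, 0.5, -1.0]), log_softmax([0.1, 0.3, 2.5])]
targets = [0, 1]
print(f"loss = {nll_loss(batch, targets):.4f}, accuracy = {accuracy(batch, targets)}")
```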

In [11]:
save_model = True
save_model_path = "../models/mnist.pt"

if save_model:
    torch.save(model, save_model_path)

## Pruning
Pruning reduces the size of a machine learning model by removing weights or neurons that contribute little to its output. Removing parameters can shrink the memory footprint and the amount of computation, although these savings only materialize if the pruned weights are stored sparsely or the pruned structures are physically removed from the network.

Pruning comes in two main forms: structured pruning, which removes larger parts of the network such as whole neurons, channels, or layers, and unstructured pruning, which removes individual weights and leaves a sparse weight tensor of the same shape.
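
The difference shows up in tensor shapes. In the toy pure-Python sketch below (the helper names are hypothetical), unstructured pruning zeros the individual entries of a weight matrix with the smallest magnitude and keeps its shape, while structured pruning drops whole rows, i.e. output neurons, ranked by L1 norm, so the matrix actually shrinks. PyTorch's own structured methods in `torch.nn.utils.prune` (e.g. `ln_structured`) instead zero entire rows or channels through a mask and keep the tensor's shape.

```python
def unstructured_prune(weights: list, amount: float) -> list:
    """Zero the `amount` fraction of individual entries with the smallest |w|; shape is kept."""
    coords = [(i, j) for i in range(len(weights)) for j in range(len(weights[0]))]
    k = round(amount * len(coords))
    smallest = set(sorted(coords, key=lambda ij: abs(weights[ij[0]][ij[1]]))[:k])
    return [[0.0 if (i, j) in smallest else w for j, w in enumerate(row)]
            for i, row in enumerate(weights)]


def structured_prune_rows(weights: list, amount: float) -> list:
    """Drop the `amount` fraction of whole rows (neurons) with the smallest L1 norm; shape shrinks."""
    k = round(amount * len(weights))
    by_norm = sorted(range(len(weights)), key=lambda i: sum(abs(w) for w in weights[i]))
    dropped = set(by_norm[:k])
    return [row for i, row in enumerate(weights) if i not in dropped]


# A 3x4 weight matrix: 3 output neurons, 4 inputs
W = [[0.9, -0.1, 0.4, 0.3],
     [0.05, 0.5, -0.03, 0.01],
     [-0.7, 0.6, 0.01, -0.2]]

print(unstructured_prune(W, 0.5))     # still 3x4, half the entries zeroed
print(structured_prune_rows(W, 1/3))  # 2x4, the weakest row removed
```

Note that unstructured pruning keeps the large 0.5 weight in the weakest row, while structured pruning removes that row entirely.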

### Unstructured Pruning
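Here we prune individual weights: across the weight tensors of all four conv and linear layers, the 95% with the smallest absolute value are zeroed.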
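
`prune.global_unstructured` with `L1Unstructured` pools all listed weights, ranks them by absolute value, and masks the smallest `amount` fraction, so individual layers can end up with very different sparsity. The pure-Python sketch below mirrors that global ranking on two hypothetical flattened layers; it is an illustration, not PyTorch's implementation, which stores a `weight_mask` buffer and recomputes `weight` from `weight_orig` before each forward pass.

```python
def global_l1_masks(tensors: dict, amount: float) -> dict:
    """0/1 masks removing the `amount` fraction of weights with the smallest |w|,
    ranked jointly across all tensors (global) rather than per tensor."""
    entries = [(abs(w), name, i) for name, ws in tensors.items() for i, w in enumerate(ws)]
    k = round(amount * len(entries))
    pruned = {(name, i) for _, name, i in sorted(entries)[:k]}
    return {name: [0 if (name, i) in pruned else 1 for i in range(len(ws))]
            for name, ws in tensors.items()}


# Hypothetical flattened weights: "conv1" has mostly large weights, "fc1" mostly small ones
layers = {"conv1": [0.8, -0.05, 0.3, 0.5],
          "fc1": [0.01, -0.6, 0.04, 0.9, -0.03, 0.02]}

masks = global_l1_masks(layers, amount=0.5)
for name, mask in masks.items():
    print(f"{name}: {100 * mask.count(0) / len(mask):.1f}% pruned")
```

Half of all weights are removed overall, but `conv1` loses only one weight (25%) while `fc1` loses four of six (66.7%).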

In [None]:
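import torch.nn.utils.prune as prune

# Test model performance before pruning (same call signature as in the training loop)
print("Before pruning")
general.test(model, device, test_loader, criterion, metric=metrics.accuracy)

# Define the parameters to prune: the weight tensor of every conv and linear layer
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)

# Zero the 95% of these weights with the smallest absolute value, ranked across all
# four layers at once. PyTorch applies this as a mask: each layer gets a `weight_orig`
# parameter and a `weight_mask` buffer, and `weight` becomes their product.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.95,
)

# Test model performance after pruning
print("After pruning")
general.test(model, device, test_loader, criterion, metric=metrics.accuracy)

# Print number of parameters in model (weight_orig keeps the full, unpruned size)
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')

# Print number of parameters pruned. Parameters never carry a `mask` attribute, so
# count the zeros in the masked weights instead.
prunable_params = sum(module.weight.numel() for module, _ in parameters_to_prune)
total_pruned_params = sum(int((module.weight == 0).sum()) for module, _ in parameters_to_prune)
print(f'{total_pruned_params:,} of {prunable_params:,} prunable weights pruned '
      f'({100 * total_pruned_params / prunable_params:.2f}% pruned)')

# Print number of parameters remaining
total_unpruned_params = total_params - total_pruned_params
print(f'{total_unpruned_params:,} parameters unpruned ({100 * total_unpruned_params / total_params:.2f}% unpruned)')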

### Structured Pruning