# Pruning Example
This notebook demonstrates how to use the pruning methods from this tool to compress a model. 

The example uses the MNIST dataset and a simple CNN model. The model is trained and then pruned using the methods in this tool. The pruned model is then evaluated on the test set to see how well it performs.

In [1]:
#Define pytorch model
import os
import sys
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch.optim as optim
import transformers
import importlib
import inspect
import torchvision.datasets as datasets
import torchvision.transforms as transforms


# Add thesis package to path
sys.path.append("../")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device
no_cuda = False
use_cuda = not no_cuda and torch.cuda.is_available()
print(f"Using cuda: {use_cuda}")

device = torch.device("cuda" if use_cuda else "cpu")

Using cuda: False


  return torch._C._cuda_getDeviceCount() > 0


### Load model

In [3]:
model_state = "../models/mnist.pt"
model_class = "models.mnist"


# Import the module class
module = importlib.import_module(model_class)

# Get all classes in the module
classes = [
    obj[1] for obj in inspect.getmembers(module, inspect.isclass)
]

# Import the classes that are Modules
for cls in classes:
    if issubclass(cls, torch.nn.Module):
        # Add the class to this package's variables
        globals()[cls.__name__] = cls

model = torch.load(model_state, map_location=torch.device(device))

Prepare data loaders for training and testing.

In [14]:
# Load MNIST dataset
batch_size = 8
test_batch_size = 1000
use_cuda = False


kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
mnist_transform = transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform,),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform,),
    batch_size=test_batch_size, shuffle=True, **kwargs),


Define train and test methods.

In [19]:
epochs = 1
lr = 0.01
momentum = 0.5
log_interval = 10
save_model = True
save_model_path = "../save/mnist_cnn.pt"

In [21]:
#Define the training function
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

#Define the test function
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [22]:
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for epoch in range(1, epochs + 1):
    # train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

if (save_model):
    torch.save(model,save_model_path)




AttributeError: 'tuple' object has no attribute 'dataset'

## Pruning

In [None]:
# Test model performance before pruning
print("Before pruning")
test(model, device, test_loader)

# Define the parameters to prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.95,
)

# Test model performance after pruning
print("After pruning")
test(model, device, test_loader)

# Print number of parameters in model
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')


# Print number of parameters pruned
total_pruned_params = sum(p.numel() for p in model.parameters() if hasattr(p, 'mask'))
print(f'{total_pruned_params:,} parameters pruned ({100 * total_pruned_params / total_params:.2f}% pruned)')

# Print number of parameters remaining
total_unpruned_params = sum(p.numel() for p in model.parameters() if not hasattr(p, 'mask'))
print(f'{total_unpruned_params:,} parameters unpruned ({100 * total_unpruned_params / total_params:.2f}% unpruned)')


Before pruning
Test set: Average loss: 0.0356, Accuracy: 9882/10000 (99%)
After pruning
Test set: Average loss: 0.3147, Accuracy: 9496/10000 (95%)
431,080 total parameters.
0 parameters pruned (0.00% pruned)
431,080 parameters unpruned (100.00% unpruned)
