# Pruning Example
This notebook demonstrates how to use the pruning methods from this tool to compress a model. 

The example uses the MNIST dataset and a simple CNN model. The model is trained and then pruned using the methods in this tool. The pruned model is then evaluated on the test set to see how well it performs.

In [1]:
#Define pytorch model
import os
import sys
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch.optim as optim
import transformers

### Model Architecture
A model architecture needs to be defined. This can be an existing model, from e.g. Hugginface, or a custom created one.
Here we create a basic CNN model as an example.

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


Set up the basic components for the training and testing of the model.

In [6]:
# Set device
no_cuda = False
use_cuda = not no_cuda and torch.cuda.is_available()
print(f"Using cuda: {use_cuda}")

device = torch.device("cuda" if use_cuda else "cpu")

# Define model, optimizer, and criterion
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

Using cuda: False


Set basic parameters for the training and testing of the model.

In [13]:
# Training settings
batch_size = 64
test_batch_size = 1000
epochs = 1
lr = 0.01
momentum = 0.5
log_interval = 10
save_model = True
save_model_path = "./save/mnist_cnn.pt"

In [11]:
# Method to load data
def get_train_test_split(batch_size, test_batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=test_batch_size, shuffle=True, **kwargs)
        
    return train_loader,test_loader

In [None]:
#Define the training function
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

#Define the test function
def test(model, device, test_loader, criterion):
    # Set model to evaluation mode
    model.eval()

    # Initialize test loss and accuracy
    test_loss = 0.0
    test_acc = 0.0

    # Disable gradients (to save memory)
    with torch.no_grad():
        # Loop over batches of test data
        for inputs, labels in test_loader:
            # Move data to device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Update test loss and accuracy
            test_loss += loss.item()
            test_acc += (outputs.argmax(dim=1) == labels).float().mean()

    
    return test_loss, test_acc

In [12]:
#Define the main function
def main():

    train_loader, test_loader = get_train_test_split(batch_size, test_batch_size)

    # Load model from file
    model = None
    if os.path.exists(save_model_path):
        model = torch.load(save_model_path).to(device)
    # If model is not loaded from file, create a new model
    if model is None:
        model = Net().to(device)

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Use the profiler to profile the model's execution
    with torch.autograd.profiler.profile() as prof:
        for epoch in range(1, epochs + 1):
            train(model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)

    # Print the profiler results
    # print(prof.key_averages().table(sort_by="cpu_time_total"))

    if (save_model):
        torch.save(model,save_model_path)




## Pruning

In [31]:

test_loader = get_train_test_split(batch_size, test_batch_size)[1]


# Test model performance before pruning
print("Before pruning")
test(model, device, test_loader)

# Define the parameters to prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.95,
)

# Test model performance after pruning
print("After pruning")
test(model, device, test_loader)

# Print number of parameters in model
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')


# Print number of parameters pruned
total_pruned_params = sum(p.numel() for p in model.parameters() if hasattr(p, 'mask'))
print(f'{total_pruned_params:,} parameters pruned ({100 * total_pruned_params / total_params:.2f}% pruned)')

# Print number of parameters remaining
total_unpruned_params = sum(p.numel() for p in model.parameters() if not hasattr(p, 'mask'))
print(f'{total_unpruned_params:,} parameters unpruned ({100 * total_unpruned_params / total_params:.2f}% unpruned)')


Before pruning
Test set: Average loss: 0.0356, Accuracy: 9882/10000 (99%)
After pruning
Test set: Average loss: 0.3147, Accuracy: 9496/10000 (95%)
431,080 total parameters.
0 parameters pruned (0.00% pruned)
431,080 parameters unpruned (100.00% unpruned)


## NLP Use Case