# Structured Pruning of a Fully-Connected PyTorch Model

[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_pruning_mnist.ipynb)

In this tutorial we walk through training, pruning, and retraining a fully connected neural network model using the PyTorch framework. The tutorial is organized in the following sections:
1. We'll start by installing and importing the necessary packages.
2. Next, we will construct and train a simple neural network on the MNIST dataset.
3. Following that, we'll introduce model pruning to reduce the model's size while maintaining accuracy.
4. Finally, we'll retrain our pruned model to recover any performance lost due to pruning.

## Installing PyTorch and the Model Compression Toolkit
We begin by setting up our environment: we install PyTorch and the Model Compression Toolkit (MCT), then import them. These packages allow us to define, train, prune, and retrain our neural network models within this notebook.

In [None]:
!pip install -q torch torchvision
!pip install -q mct-nightly

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import model_compression_toolkit as mct

## Loading and Preprocessing MNIST Dataset
Let's create a function to retrieve the train and test parts of the MNIST dataset, including preprocessing:

In [None]:
# MNIST Data Loading and Preprocessing
def load_and_preprocess_mnist(batch_size=128, root_path='./data'):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.MNIST(root=root_path, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root=root_path, train=False, download=True, transform=transform)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader
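
To make the preprocessing concrete, here is a minimal pure-Python sketch of the arithmetic the two transforms apply to a single pixel. `ToTensor` scales 8-bit pixel values from [0, 255] to [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(x - mean) / std`, which maps that range to [-1, 1]. The helper name `normalize_pixel` is ours and exists only for illustration.

```python
def normalize_pixel(raw_value, mean=0.5, std=0.5):
    """Mimic ToTensor followed by Normalize for one 8-bit pixel value."""
    scaled = raw_value / 255.0      # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std    # Normalize: (x - mean) / std


# The darkest and brightest pixels land on the ends of the [-1, 1] range
for raw in (0, 255):
    print(raw, '->', normalize_pixel(raw))
```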

## Creating a Fully-Connected Model
In this section, we create a simple example of a fully connected model to demonstrate the pruning process. It consists of three linear layers with 128, 64, and 10 neurons.

In [None]:
# Define the Fully-Connected Model
class FCModel(nn.Module):
    def __init__(self):
        super(FCModel, self).__init__()
        self.flatten = nn.Flatten()
        self.fc_layers = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.fc_layers(x)
        return logits
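
As a quick sanity check on the model size, the parameter count can be worked out by hand: a linear layer has one weight per input-output connection plus one bias per output. The sketch below uses a hypothetical helper, `linear_params`, and needs no framework.

```python
def linear_params(in_features, out_features):
    """Parameters of a linear layer: one weight per connection plus one bias per output."""
    return in_features * out_features + out_features


# (in_features, out_features) of the three linear layers in FCModel
layer_shapes = [(28 * 28, 128), (128, 64), (64, 10)]
per_layer = [linear_params(i, o) for i, o in layer_shapes]
total_params = sum(per_layer)
print(per_layer, total_params)  # [100480, 8256, 650] 109386
```

The first layer alone accounts for roughly 92% of the parameters, which is why most of the savings from pruning come from there.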

## Defining the Training Function

Next, we'll define two helper functions. `train_model` runs the training loop: forward propagation, loss calculation, backpropagation, and parameter updates. `test_model` evaluates a model, and `train_model` calls it on the test dataset at the end of each epoch to monitor accuracy.

In [None]:
def test_model(model, test_loader):
    # Evaluate on whichever device holds the model's parameters,
    # rather than relying on a global `device` variable
    device = next(model.parameters()).device
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The predicted class is the index of the largest logit
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Training the Dense Model
def train_model(model, train_loader, test_loader, device, epochs=6):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        accuracy = test_model(model, test_loader)
        print(f'Epoch [{epoch+1}/{epochs}], Test Accuracy: {accuracy:.2f}%')
    return model
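
The evaluation step boils down to taking the index of the largest logit in each row as the predicted class and counting how often it matches the label. A framework-free sketch of that logic follows; the helper `top1_accuracy` and the logits are made up for illustration.

```python
def top1_accuracy(logits, labels):
    """Percentage of rows whose largest logit sits at the label's index."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(predicted == label)
    return 100 * correct / len(labels)


# Made-up logits for three samples and three classes
logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.4]]
labels = [1, 0, 1]
print(top1_accuracy(logits, labels))  # two of the three predictions match
```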

## Training the Dense Model
We will now train the dense model using the MNIST dataset.

In [None]:
train_loader, test_loader = load_and_preprocess_mnist()
dense_model = FCModel()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dense_model = train_model(dense_model, train_loader, test_loader, device, epochs=6)

## Dense Model Properties
We will display our model's architecture, including its layers, their types, and the number of parameters.
Notably, MCT's structured pruning will target the first two dense layers. These layers have more output channels (neurons) than the last one and hold almost all of the model's parameters, so they offer the most room for pruning without significantly affecting accuracy. The 10 outputs of the last layer are the class logits and have to stay. Removing output channels from a layer is propagated to the next layer by removing its matching input channels.

In [None]:
def display_model_params(model):
    model_params = sum(p.numel() for p in model.state_dict().values())
    for name, module in model.named_modules():
        module_params = sum(p.numel() for p in module.state_dict().values())
        if module_params > 0:
            print(f'{name} number of parameters {module_params}')
    print(f'{model}\nTotal number of parameters {model_params}')
    return model_params

dense_model_params = display_model_params(dense_model)
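
To illustrate how removing output channels propagates to the next layer, here is a toy sketch on plain nested lists. It is not MCT's implementation: MCT chooses which channels to keep from importance scores, whereas here the `keep` list is given by hand, and `prune_hidden_units` is a hypothetical helper.

```python
def prune_hidden_units(w1, b1, w2, keep):
    """Keep only the hidden units listed in `keep`.

    w1: first-layer weights, one row per output channel (out x in)
    b1: first-layer biases, one per output channel
    w2: second-layer weights, one column per hidden unit (out x hidden)
    """
    w1_pruned = [w1[j] for j in keep]                   # drop rows (output channels)
    b1_pruned = [b1[j] for j in keep]                   # drop the matching biases
    w2_pruned = [[row[j] for j in keep] for row in w2]  # drop the matching input columns
    return w1_pruned, b1_pruned, w2_pruned


# Toy layers: 3 inputs -> 4 hidden units -> 2 outputs
w1 = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]
b1 = [0, 0, 0, 0]
w2 = [[1, 2, 3, 4], [5, 6, 7, 8]]

w1_p, b1_p, w2_p = prune_hidden_units(w1, b1, w2, keep=[0, 3])
print(len(w1_p), len(w2_p[0]))  # both layers now agree on 2 hidden units
```

Because whole rows and columns are removed, the result is a smaller dense model rather than a sparse one, so no special sparse kernels are needed to benefit from it.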

## Create a Representative Dataset
We create a representative dataset, which the pruning process uses to compute an importance score for each channel:

In [None]:
# Create a representative dataset
ds_train_as_iter = iter(train_loader)

def representative_data_gen():
    # Generator that yields one batch of images (labels are dropped) per call
    yield [next(ds_train_as_iter)[0]]
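
Each call to the generator above yields a single batch and advances the shared iterator. If several batches per call are wanted, or the generator should restart when the data runs out, one possible variant is sketched below. This is a sketch under assumptions rather than MCT's recommended pattern: `make_representative_data_gen` and `n_iter` are hypothetical names, and the toy `batches` list stands in for `train_loader`.

```python
def make_representative_data_gen(loader, n_iter=10):
    """Build a generator function that yields `n_iter` single-input batches per call."""
    def representative_data_gen():
        data_iter = iter(loader)
        for _ in range(n_iter):
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(loader)  # restart when the loader is exhausted
                batch = next(data_iter)
            yield [batch[0]]              # keep the images, drop the labels
    return representative_data_gen


# Toy stand-in for a DataLoader: a list of (images, labels) pairs
batches = [('images_0', 'labels_0'), ('images_1', 'labels_1')]
gen = make_representative_data_gen(batches, n_iter=3)
print(list(gen()))
```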

## Pruning the Model
Next, we'll prune our trained model to decrease its size, targeting a 50% reduction in the memory footprint of the model's weights. The weights use the float32 data type, so each parameter occupies 4 bytes, and the memory requirement is the total number of parameters multiplied by 4.

In [None]:
compression_ratio = 0.5
# Define Resource Utilization constraint for pruning. Each float32 parameter requires 4 bytes,
# hence we multiply the total parameter count by 4 to calculate the memory footprint.
target_resource_utilization = mct.core.ResourceUtilization(weights_memory=dense_model_params * 4 * compression_ratio)
# Define a pruning configuration
pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)
# Prune the model
pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(
    model=dense_model,
    target_resource_utilization=target_resource_utilization,
    representative_data_gen=representative_data_gen,
    pruning_config=pruning_config)
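
The memory budget passed to the pruning call can be reproduced by hand. The sketch below assumes the 784-128-64-10 model defined earlier, whose weights and biases total 109,386 parameters. The `example_` variable names are ours and are chosen so they do not overwrite the notebook's variables.

```python
BYTES_PER_FLOAT32 = 4

example_dense_params = 109386  # weights and biases of the 784-128-64-10 model
example_ratio = 0.5            # keep half of the weights memory

dense_bytes = example_dense_params * BYTES_PER_FLOAT32
target_bytes = dense_bytes * example_ratio
print(dense_bytes, target_bytes)  # 437544 218772.0
```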

### Model after pruning
Let's inspect the model after pruning and check its accuracy. We can see that the pruning process caused a degradation in accuracy.

In [None]:
pruned_model_nparams = display_model_params(pruned_model)
acc_before_retrain = test_model(pruned_model, test_loader)
print(f'Pruned model accuracy before retraining: {acc_before_retrain:.2f}%')
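
To check how close the pruned model comes to the requested budget, the two parameter counts can be compared directly. The helper below is a hypothetical sketch, and the counts passed to it are placeholders rather than results from this notebook.

```python
def compression_summary(dense_params, pruned_params, bytes_per_param=4):
    """Return the pruned size in bytes and the fraction of the dense size that remains."""
    pruned_bytes = pruned_params * bytes_per_param
    remaining = pruned_params / dense_params
    return pruned_bytes, remaining


# Placeholder counts for illustration only; in the notebook, pass
# dense_model_params and pruned_model_nparams instead.
pruned_bytes, remaining = compression_summary(dense_params=1000, pruned_params=480)
print(f'{pruned_bytes} bytes, {remaining:.1%} of the dense model')
```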

## Retraining the Pruned Model
After pruning, we often need to retrain the model to recover any lost performance.

In [None]:
pruned_model_retrained = train_model(pruned_model, train_loader, test_loader, device, epochs=6)

## Summary
In this tutorial, we demonstrated the process of training, pruning, and retraining a neural network model using the Model Compression Toolkit. We began by setting up our environment and loading the dataset, followed by building and training a fully connected neural network. We then introduced the concept of model pruning, specifically targeting the first two dense layers to efficiently reduce the model's memory footprint by 50%. After applying structured pruning, we evaluated the pruned model's performance and concluded the tutorial by fine-tuning the pruned model to recover any lost accuracy due to the pruning process. This tutorial provided a hands-on approach to model optimization through pruning, showcasing the balance between model size, performance, and efficiency.

Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
