## Using SageMaker Debugger and SageMaker Experiments for iterative model pruning

This notebook demonstrates how to use SageMaker Debugger and SageMaker Experiments to perform iterative model pruning. Let's start with a quick introduction to model pruning.

State-of-the-art deep learning models consist of millions of parameters and are trained on very large datasets. In transfer learning we take a pre-trained model and fine-tune it on a new and typically much smaller dataset. The new dataset may even consist of different classes, so the model is essentially learning a new task. This process allows us to quickly achieve state-of-the-art results without having to design and train our own model from scratch. However, a much smaller and simpler model might perform just as well on our dataset. With model pruning we estimate the importance of weights during training and remove the weights that contribute negligibly to the learning process. We can do this iteratively, removing, say, 5% of the weights in each iteration.
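As a rough illustration of such an iterative schedule, the sketch below computes how many parameters remain when a fixed percentage is removed in every iteration. The starting size and the 5% rate are example values only, and `remaining_parameters` is a hypothetical helper that is not part of any SageMaker library.

```python
def remaining_parameters(initial_count, prune_percent, iterations):
    """Return the parameter count after each pruning iteration.

    Assumes the same percentage of the *current* parameters is removed
    in every iteration. This is an illustrative schedule only; the
    workflow later in this notebook removes a fixed number of filters.
    """
    counts = [initial_count]
    for _ in range(iterations):
        # Integer arithmetic keeps the counts exact.
        counts.append(counts[-1] * (100 - prune_percent) // 100)
    return counts


schedule = remaining_parameters(1_000_000, 5, 3)
print(schedule)
```

Because each step removes a percentage of what is left, the absolute number of removed parameters shrinks from one iteration to the next.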

We use SageMaker Debugger to get weights, activation outputs and gradients during training. These tensors are used to compute the importance of weights. We use SageMaker Experiments to keep track of each pruning iteration: if we prune too much we may degrade model accuracy, so we monitor the number of parameters versus validation accuracy.


In [None]:
! pip install sagemaker
! pip install sagemaker-experiments
! pip install torchsummary

### Load and save AlexNet model

First we get a pre-trained AlexNet model from the PyTorch model zoo.

In [None]:
import torch
from torchvision import models
from torch import nn
from pytorch_iterative_model_pruning import model_alexnet

model = models.alexnet(pretrained=True)

The pre-trained model was trained on ImageNet, which has 1000 classes. We will fine-tune it on CIFAR10, so we set the number of output classes to 10.

In [None]:
model.classifier[6] = nn.Linear(4096, 10)

Next we store the model definition and weights in an output file.

In [None]:
checkpoint = {'model': model,
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'src/checkpoint_model_pruned')     

The following code cell creates a SageMaker experiment:

In [None]:
import boto3
from smexperiments.experiment import Experiment

sagemaker_boto_client = boto3.client("sagemaker")

experiment = Experiment.create(
    experiment_name="model-pruning-experiment", 
    description="Iterative model pruning of AlexNet trained on CIFAR10", 
    sagemaker_boto_client=sagemaker_boto_client)

The following code cell loads the lists of tensor names that are considered for pruning. The lists cover the activation outputs, gradients, weights and biases of all convolutional layers, as well as the weights and biases of the fully connected layers of the classifier.

In [None]:
activation_outputs = model_alexnet.activation_outputs
gradients = model_alexnet.gradients
weights = model_alexnet.weights
biases = model_alexnet.biases
classifier_weights = model_alexnet.classifier_weights
classifier_biases = model_alexnet.classifier_biases

In [None]:
import sagemaker 

sagemaker_boto_client = boto3.client("sagemaker")

experiment_config = { "ExperimentName": "model-pruning-experiment", 
                      "TrialName": None,
                      "TrialComponentDisplayName": "Training"}

sagemaker_session = sagemaker.Session()
BUCKET_NAME = sagemaker_session.default_bucket()
LOCATION_IN_BUCKET = 'smdebug-model-pruning-example'

### Iterative model pruning: step by step

Before we jump into the code for running the iterative model pruning we will walk through the code step by step. First we create a new trial for each pruning iteration. That allows us to track our training jobs and see which models have the lowest number of parameters and best accuracy. We use the `smexperiments` library to create a trial within an experiment named `model-pruning-experiment`.

```python
from smexperiments.trial import Trial
from smexperiments.tracker import Tracker
import boto3

sagemaker_boto_client = boto3.client("sagemaker")

trial = Trial.create(
        experiment_name="model-pruning-experiment",
        sagemaker_boto_client=sagemaker_boto_client
    )
```

Next we define the `experiment_config`, a dictionary that is passed to the SageMaker training job. It associates the training job with a trial and an experiment. If we don't specify an `experiment_config`, the training job appears in SageMaker Experiments under `Unassigned trial components`.

```python 
experiment_config = {
                        "ExperimentName": "model-pruning-experiment", 
                        "TrialName": trial.trial_name,
                        "TrialComponentDisplayName": "Training",
                    }
```

We use the SageMaker default bucket to store the tensors emitted by Debugger.

```python
import sagemaker 

sagemaker_session = sagemaker.Session()
BUCKET_NAME = sagemaker_session.default_bucket()
LOCATION_IN_BUCKET = 'model-pruning-example'
TRIAL_NAME = trial.trial_name

s3_bucket_for_tensors = 's3://{BUCKET_NAME}/{LOCATION_IN_BUCKET}/{TRIAL_NAME}'.format(BUCKET_NAME=BUCKET_NAME, LOCATION_IN_BUCKET=LOCATION_IN_BUCKET, TRIAL_NAME=TRIAL_NAME)

```
We create a Debugger hook configuration to define a custom collection of tensors to be emitted. The custom collection contains all weights and biases of the model. It also includes individual layer outputs and their gradients, which will be used to compute filter ranks. Tensors are saved every 100th iteration, where an iteration represents one forward and backward pass.

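```python
from sagemaker.debugger import DebuggerHookConfig, CollectionConfig

debugger_hook_config = DebuggerHookConfig(
    s3_output_path=s3_bucket_for_tensors,
    collection_configs=[
        CollectionConfig(
            name="custom_collection",
            # save every tensor whose name matches the regex, every 100 steps
            parameters={"include_regex": ".*output_0|.*weight|.*bias",
                        "save_interval": "100"})])
```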
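To get a feel for which tensors such an `include_regex` selects, the following sketch applies the same pattern to a few tensor names with Python's `re` module. Apart from `AlexNet_classifier.1.weight`, the names are hypothetical, and the use of search semantics is an assumption about how Debugger applies the pattern (with the leading `.*` the result is the same for match semantics).

```python
import re

# Same pattern as in the hook configuration above.
INCLUDE_REGEX = ".*output_0|.*weight|.*bias"

# Mostly hypothetical tensor names, chosen only to illustrate the matching.
tensor_names = [
    "AlexNet_classifier.1.weight",
    "AlexNet_classifier.1.bias",
    "features.0_output_0",
    "features.0_input_0",
    "CrossEntropyLoss_output_1",
]


def select_tensors(names, pattern):
    """Return the names in which the pattern is found."""
    compiled = re.compile(pattern)
    return [name for name in names if compiled.search(name)]


selected = select_tensors(tensor_names, INCLUDE_REGEX)
print(selected)
```

Checking the pattern against the tensor names that a job actually emitted (for example via `smdebug_trial.tensor_names()`) is a quick way to confirm that the collection contains what the pruning step needs.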
Now we define the SageMaker PyTorch estimator. We will train the model on an `ml.p2.xlarge` instance. The model definition and the training code are defined in the `entry_point` file `train.py`.

```python
from sagemaker.pytorch import PyTorch

estimator = PyTorch(role=sagemaker.get_execution_role(),
                  train_instance_count=1,
                  train_instance_type='ml.p2.xlarge',
                  train_volume_size=400,
                  source_dir='src',
                  entry_point='train.py',
                  framework_version='1.3.1',
                  py_version='py3',
                  metric_definitions=[ {'Name':'train:loss', 'Regex':'loss:(.*?)'}, {'Name':'eval:acc', 'Regex':'acc:(.*?)'} ],
                  enable_sagemaker_metrics=True,
                  hyperparameters = {'epochs': 10},
                  debugger_hook_config=debugger_hook_config
        )
```
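The `Regex` entries in `metric_definitions` are applied to the log output of the training job. The sketch below uses Python's `re` module on a hypothetical log line to show a pitfall worth checking: a lazy group such as `(.*?)` with nothing after it can match the empty string, whereas an explicit numeric pattern captures the value. The log format is an assumption (the real one is whatever `train.py` prints), and the regex handling of the service may differ from Python's.

```python
import re

# Hypothetical log line; the actual format depends on what train.py prints.
log_line = "epoch 3 loss:0.4271 acc:0.8125"

lazy_pattern = "loss:(.*?)"          # pattern used in the estimator above
numeric_pattern = r"loss:([0-9.]+)"  # assumed alternative for numeric values

# The lazy group is satisfied with zero characters because nothing follows it.
lazy_value = re.search(lazy_pattern, log_line).group(1)
# The numeric group must consume at least one digit or dot.
numeric_value = re.search(numeric_pattern, log_line).group(1)

print(repr(lazy_value))
print(repr(numeric_value))
```

If the metrics do not show up as expected in SageMaker Experiments, testing the patterns against a real log line in this way is a reasonable first step.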
Once we have the estimator object we can call `fit`, which launches an `ml.p2.xlarge` instance and starts the training job on it.

```python
estimator.fit(experiment_config=experiment_config)
```

Once the training job has finished, we retrieve its tensors, such as gradients, weights and biases. We use the `smdebug` library, which provides functions to read and filter tensors. First we create a trial that reads the tensors from S3.

```python
from smdebug.trials import create_trial

smdebug_trial = create_trial(s3_bucket_for_tensors)
```


To access tensor values, we call `smdebug_trial.tensor()`. For instance, to get the weights of the first fully connected layer at step 0, we run `smdebug_trial.tensor('AlexNet_classifier.1.weight').value(0, mode=modes.TRAIN)`, where `modes` is imported from `smdebug`. Next we compute a filter rank for the convolutions.

To recap: a filter is a collection of kernels (one kernel for every input channel), and each filter produces one feature map (output channel). In the image below the convolution creates 64 feature maps (output channels) and uses a kernel of 5x5. Pruning a filter removes an entire feature map. So in the example image below the number of feature maps (output channels) would shrink to 63 and the number of learnable parameters (weights) would be reduced by 1x5x5, that is, one 5x5 kernel for the single input channel.

![](images/convolution.png) 
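The parameter arithmetic of this example can be sketched in a few lines. The helper `conv_parameters` is hypothetical, and the single input channel is an assumption derived from the 1x5x5 figure quoted above; biases are left out to match that figure.

```python
def conv_parameters(in_channels, out_channels, kernel_size, bias=True):
    """Number of learnable parameters in a 2D convolution with square kernels."""
    weights = out_channels * in_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


# Values from the example: 1 input channel, 64 filters, 5x5 kernels.
before = conv_parameters(1, 64, 5, bias=False)
# After pruning one filter, 63 output channels remain.
after = conv_parameters(1, 63, 5, bias=False)

print(before, after, before - after)
```

In a real network the layer that follows also loses the corresponding input channel, so the total saving per pruned filter is larger than the filter's own weights.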


In this notebook we compute filter ranks as described in the article ["Pruning Convolutional Neural Networks for Resource Efficient Inference"](https://arxiv.org/pdf/1611.06440.pdf).
In the following code we retrieve activation outputs and gradients and compute the first-order Taylor expansion that is used to measure the filter rank.

```python
import numpy as np
from smdebug import modes

# accumulate the rank of every filter over all saved training steps
filters = {}
for activation_output_name, gradient_name in zip(activation_outputs, gradients):
    for step in smdebug_trial.steps(mode=modes.TRAIN):
        activation_output = smdebug_trial.tensor(activation_output_name).value(step, mode=modes.TRAIN)
        gradient = smdebug_trial.tensor(gradient_name).value(step, mode=modes.TRAIN)
        # activation times gradient, averaged over batch, height and width
        rank = activation_output * gradient
        rank = np.mean(rank, axis=(0, 2, 3))

        if activation_output_name not in filters:
            filters[activation_output_name] = 0
        filters[activation_output_name] += rank
```

Next we normalize the filter ranks of each layer:
```python
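for activation_output_name in filters:
    # take magnitudes and scale each layer's ranks to unit L2 norm
    rank = np.abs(filters[activation_output_name])
    rank = rank / np.sqrt(np.sum(rank * rank))
    filters[activation_output_name] = rank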
```
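The two steps above can be tried without a training job. The following sketch runs the same rank computation and normalization on random NumPy arrays that stand in for the tensors read from Debugger; the shapes and values are synthetic and only serve to show what the operations do.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for tensors read from Debugger,
# with shape (batch, channels, height, width).
activation_output = rng.normal(size=(8, 4, 6, 6))
gradient = rng.normal(size=(8, 4, 6, 6))

# Activation times gradient, averaged over batch, height and width,
# leaves one value per channel (filter).
rank = np.mean(activation_output * gradient, axis=(0, 2, 3))

# Take magnitudes and scale the layer's ranks to unit L2 norm.
rank = np.abs(rank)
rank = rank / np.sqrt(np.sum(rank * rank))

print(rank.shape, np.linalg.norm(rank))
```

Normalizing per layer makes the ranks of different layers comparable, which matters because the next step sorts the filters of all layers in a single list.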

We create a list of filters, sort it by rank and retrieve the smallest values:

```python
filters_list = []
for layer_name in sorted(filters.keys()):
    for channel in range(filters[layer_name].shape[0]): 
        filters_list.append((layer_name, channel, filters[layer_name][channel], ))

filters_list.sort(key = lambda x: x[2])
filters_list = filters_list[:100]
print("The 100 smallest filters", filters_list)
```
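The selection step can be illustrated on a toy example. The sketch below uses made-up layer names and rank values, and picks the lowest-ranked filters with `heapq.nsmallest` instead of sorting the full list; it then groups the selected channels by layer, which is one possible way to hand them to a pruning routine.

```python
import heapq

# Hypothetical normalized ranks per layer (layer name -> rank per channel).
filters = {
    "conv_a": [0.70, 0.10, 0.55],
    "conv_b": [0.05, 0.90, 0.30],
}

# Flatten into (layer_name, channel, rank) tuples.
candidates = [
    (layer_name, channel, rank)
    for layer_name in sorted(filters)
    for channel, rank in enumerate(filters[layer_name])
]

# Pick the 3 lowest-ranked filters across all layers, in ascending order.
lowest = heapq.nsmallest(3, candidates, key=lambda item: item[2])

# Group the selected channels by layer.
per_layer = {}
for layer_name, channel, _ in lowest:
    per_layer.setdefault(layer_name, []).append(channel)

print(lowest)
print(per_layer)
```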
Next we prune the model by removing the selected filters and their corresponding weights. The new model definition and weights are saved under `src` and will be used by the next training job.

```python
model = model_alexnet.prune(model, 
                    activation_outputs, 
                    weights, 
                    biases, 
                    classifier_weights, 
                    classifier_biases, 
                    filters_list, 
                    smdebug_trial, 
                    step)

checkpoint = {'model': model,
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'src/checkpoint_model_pruned')  
del model
```
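The internals of `model_alexnet.prune` are not shown here. As a rough sketch of what removing a filter involves, the following NumPy example deletes one output channel from a convolution's weights and bias, and the matching input channel from the next convolution. The layer shapes are hypothetical and this is not the actual implementation.

```python
import numpy as np

# Hypothetical shapes: conv1 maps 3 -> 8 channels, conv2 maps 8 -> 16 channels.
conv1_weight = np.zeros((8, 3, 5, 5))  # (out_channels, in_channels, kH, kW)
conv1_bias = np.zeros(8)
conv2_weight = np.zeros((16, 8, 3, 3))

channel_to_prune = 2

# Remove the filter (output channel) and its bias from the first layer.
conv1_weight = np.delete(conv1_weight, channel_to_prune, axis=0)
conv1_bias = np.delete(conv1_bias, channel_to_prune, axis=0)

# The next layer no longer receives that feature map, so drop its input channel.
conv2_weight = np.delete(conv2_weight, channel_to_prune, axis=1)

print(conv1_weight.shape, conv1_bias.shape, conv2_weight.shape)
```

When the pruned convolution feeds a fully connected layer, the corresponding columns of that layer's weight matrix have to be removed in the same way, which is presumably why the classifier weights and biases are passed to `prune` as well.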


The overall workflow looks like the following:
![](images/workflow.png)

### Run iterative model pruning

After having gone through the code step by step, we are ready to run the full workflow. The following cell runs 10 pruning iterations. Each iteration starts a new SageMaker training job, which emits gradients and activation outputs to Amazon S3. Once the job has finished, filter ranks are computed and the 100 lowest-ranked filters are removed.



In [None]:
import numpy as np
from sagemaker.pytorch import PyTorch
from smexperiments.trial import Trial
from smdebug.trials import create_trial
from smdebug import modes
from sagemaker.debugger import DebuggerHookConfig, CollectionConfig
from torchsummary import summary

# start iterative pruning
for pruning_step in range(10):
    
    #create new trial for this pruning step
    smexperiments_trial = Trial.create(
        experiment_name="model-pruning-experiment",
        sagemaker_boto_client=sagemaker_boto_client
    )
    experiment_config["TrialName"] = smexperiments_trial.trial_name

    # s3 path where tensors will be stored
    s3_bucket_for_tensors = 's3://{BUCKET_NAME}/{LOCATION_IN_BUCKET}/{TRIAL_NAME}'.format(BUCKET_NAME=BUCKET_NAME, LOCATION_IN_BUCKET=LOCATION_IN_BUCKET, TRIAL_NAME=smexperiments_trial.trial_name)

    print("Created new trial", smexperiments_trial.trial_name, "for pruning step", pruning_step)
    
    #debugger hook configuration for custom collection
    debugger_hook_config = DebuggerHookConfig(
                  s3_output_path=s3_bucket_for_tensors,  
                  collection_configs=[ 
                      CollectionConfig(
                            name="custom_collection",
                            parameters={ "include_regex": ".*output|.*weight|.*bias",
                                         "save_interval": "100" })])
        
    # train the models
    estimator = PyTorch(role=sagemaker.get_execution_role(),
                  train_instance_count=1,
                  train_instance_type='ml.p2.xlarge',
                  train_volume_size=400,
                  source_dir='src',
                  entry_point='train.py',
                  framework_version='1.3.1',
                  py_version='py3',
                  metric_definitions=[ {'Name':'train:loss', 'Regex':'loss:(.*?)'}, {'Name':'eval:acc', 'Regex':'acc:(.*?)'} ],
                  enable_sagemaker_metrics=True,
                  hyperparameters = {'epochs': 10},
                  debugger_hook_config = debugger_hook_config
        )
    
    #start training job
    estimator.fit(experiment_config=experiment_config)

    print("Training job", estimator.latest_training_job.name , "finished. Read tensors from ", s3_bucket_for_tensors)
    
    # read tensors
    path = estimator.latest_job_debugger_artifacts_path()
    smdebug_trial = create_trial(path)
    
    # compute filter ranks
    filters = {}
    for activation_output_name, gradient_name in zip(activation_outputs, gradients):
        for step in smdebug_trial.steps(mode=modes.TRAIN):
            activation_output = smdebug_trial.tensor(activation_output_name).value(step, mode=modes.TRAIN)
            gradient = smdebug_trial.tensor(gradient_name).value(step, mode=modes.TRAIN)
            rank = activation_output * gradient
            rank = np.mean(rank, axis=(0,2,3))

            if activation_output_name not in filters:
                filters[activation_output_name] = 0
            filters[activation_output_name] += rank
        
        #normalize
        rank = np.abs(filters[activation_output_name])
        rank = rank / np.sqrt(np.sum(rank * rank))
        filters[activation_output_name] = rank
        
    # find lowest ranked filters
    filters_list = []
    for layer_name in sorted(filters.keys()):
        for channel in range(filters[layer_name].shape[0]): 
            filters_list.append((layer_name, channel, filters[layer_name][channel], ))

    filters_list.sort(key = lambda x: x[2])
    filters_list = filters_list[:100]
    print("The 100 smallest filters", filters_list)
        
    #load previous model 
    checkpoint = torch.load("src/checkpoint_model_pruned")
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    
    #print parameters per layer
    print("Pruning iteration:", pruning_step)
    print(summary(model, (3, 64, 64)))
    
    #prune model
    model = model_alexnet.prune(model, 
                        activation_outputs, 
                        weights, 
                        biases, 
                        classifier_weights, 
                        classifier_biases, 
                        filters_list, 
                        smdebug_trial, 
                        step)
    
    print("Saving pruned model")
    
    # save pruned model
    checkpoint = {'model': model,
                  'state_dict': model.state_dict()}
    torch.save(checkpoint, 'src/checkpoint_model_pruned')
    
    #clean up
    del model

As the iterative model pruning runs, we can track and visualize our experiment in SageMaker Studio. The following image shows the number of parameters versus the model accuracy. The initial model consisted of 57 million parameters. After 10 iterations the number of parameters was reduced to 18 million. Accuracy first increased to 97% and then dropped after the 6th pruning iteration. This suggests that the best accuracy is reached when the model has about 20 million parameters.

![](images/results.png) 

### Results
The following animation shows the number of parameters per layer for each pruning iteration. We can see that most of the parameters are pruned in the last convolutional layers. The model starts with 57 million parameters and a size of 218 MB. After 10 iterations it consists of only 18 million parameters and 70 MB. Fewer parameters mean a smaller model, and hence faster training and inference.

![](images/results.gif)
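As a quick sanity check of these numbers, the sketch below estimates the size of the weights from the parameter counts. It assumes 32-bit floating point weights (4 bytes per parameter) and uses the rounded counts quoted above, so the results only approximate the reported file sizes.

```python
BYTES_PER_PARAMETER = 4  # assumes 32-bit floating point weights


def model_size_mib(parameter_count):
    """Approximate size of the weights in MiB."""
    return parameter_count * BYTES_PER_PARAMETER / 2**20


params_before, params_after = 57_000_000, 18_000_000
reduction = 1 - params_after / params_before

print(f"parameters removed: {reduction:.0%}")
print(f"size before: {model_size_mib(params_before):.0f} MiB")
print(f"size after: {model_size_mib(params_after):.0f} MiB")
```

The estimates land close to the 218 MB and 70 MB reported above, and show that roughly two thirds of the parameters were removed over the 10 iterations.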