In [None]:
#| include: false
from nbdev.showdoc import *

Neural network pruning usually follows one of the following three schedules:

![](../../imgs/schedules.png "Schedules")

In fasterai, all three schedules can be applied from the **same** callback. We'll cover each of them below.

The `SparsifyCallback` exposes several parameters that shape the pruning schedule:
* `start_sparsity`: the initial sparsity of the model, generally kept at 0 because freshly initialized weights are generally non-zero.
* `end_sparsity`: the target sparsity at the end of training.
* `start_epoch`: we can decide to start pruning right from the beginning or let the model train for a while before removing weights.
* `sched_func`: this sets the general shape of the schedule, i.e. how the sparsity evolves over the course of training. You can either use a schedule [available](https://docs.fast.ai/callback.schedule.html#Annealing) in fastai or even come up with your own!
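
To make the interplay of these parameters concrete, here is a minimal sketch in plain Python. It is not fasterai's implementation: `sparsity_at` and `linear` are hypothetical helpers written for this illustration, and holding the sparsity at `start_sparsity` before `start_epoch` is a simplifying assumption.

```python
def linear(start, end, pos):
    """Hypothetical schedule function: interpolate linearly from start to end."""
    return start + (end - start) * pos


def sparsity_at(epoch_pos, n_epochs, start_sparsity, end_sparsity,
                start_epoch, sched_func):
    """Target sparsity (in %) at a possibly fractional epoch position.

    Assumption: before `start_epoch` nothing is pruned, so the sparsity
    simply stays at `start_sparsity`.
    """
    if n_epochs <= start_epoch:
        raise ValueError("start_epoch must be smaller than n_epochs")
    if epoch_pos < start_epoch:
        return start_sparsity
    # Normalize the remaining training time to a position in [0, 1]
    pos = min((epoch_pos - start_epoch) / (n_epochs - start_epoch), 1.0)
    return sched_func(start_sparsity, end_sparsity, pos)


# 10 epochs in total, pruning starts after 3 epochs, target sparsity 90%
for epoch in (0, 3, 6.5, 10):
    print(epoch, sparsity_at(epoch, 10, 0, 90, 3, linear))
```

Swapping `linear` for another function with the same `(start, end, pos)` arguments changes the shape of the curve while the start point, end point and delay stay the same.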

---

In [None]:
#| include: false
from fastai.vision.all import *
from fastai.callback.all import *

from fasterai.sparse.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F

import seaborn as sns

sns.set(context='poster', style='white',
        font='sans-serif', font_scale=1, color_codes=True, rc=None)


In [None]:
path = untar_data(URLs.PETS)

files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

We will first train a network without any pruning, which will serve as a baseline.

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

learn.fit_one_cycle(10)

epoch,train_loss,valid_loss,accuracy,time
0,0.696587,0.398289,0.858593,00:02
1,0.446161,0.596597,0.841001,00:02
2,0.311823,0.447651,0.788904,00:02
3,0.221294,0.325151,0.882273,00:02
4,0.16857,0.210183,0.914073,00:02
5,0.113579,0.247422,0.916103,00:02
6,0.092339,0.213218,0.924899,00:02
7,0.059152,0.182308,0.939107,00:02
8,0.030035,0.190679,0.939784,00:02
9,0.019936,0.190289,0.939107,00:02


## One-Shot Pruning

The simplest way to perform pruning is called One-Shot Pruning. It consists of the following three steps:

1. You first need to train a network
2. You then need to remove some weights (depending on your criteria, needs, etc.)
3. You fine-tune the remaining weights to recover from the loss of parameters.
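
Step 2 can be sketched on a plain list of numbers. This is only an illustration under stated assumptions: `prune_by_magnitude` is a hypothetical helper, it uses a simple magnitude criterion (keep the largest weights, zero the smallest), and the training and fine-tuning of steps 1 and 3 are left out.

```python
def prune_by_magnitude(weights, sparsity):
    """Zero out the `sparsity` percent of weights with the smallest magnitude."""
    n_prune = round(len(weights) * sparsity / 100)
    # Indices ordered by absolute value, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)  # work on a copy, leave the input untouched
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


# Toy "trained" weights (made-up values, for illustration only)
weights = [0.8, -0.05, 0.3, -0.9, 0.01, 0.4, -0.2, 0.6, -0.02, 0.1]
pruned = prune_by_magnitude(weights, sparsity=50)
print(pruned)
```

In a single shot, half of the toy weights are set to zero; a real network would then be fine-tuned so that the surviving weights compensate for the removed ones.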

With fasterai, this is really easy to do. Let's illustrate it with an example:

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning. This training can be done independently of the pruning callback, or emulated with the `start_epoch` argument, which delays the pruning process.

You thus only need to create the Callback with the `one_shot` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.

In [None]:
sp_cb=SparsifyCallback(sparsity=90, granularity='weight', context='local', criteria=large_final, schedule=one_shot)

Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before

In [None]:
learn.fit(10, cbs=sp_cb)

Pruning of weight until a sparsity of [90]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.58421,0.4012,0.826116,00:02
1,0.357988,0.349748,0.841678,00:02
2,0.254447,0.250878,0.895129,00:02
3,0.206469,0.31746,0.872801,00:02
4,0.223192,0.402789,0.841678,00:02
5,0.254638,0.253465,0.886333,00:02


Sparsity at the end of epoch 0: [0.0]%
Sparsity at the end of epoch 1: [0.0]%
Sparsity at the end of epoch 2: [0.0]%
Sparsity at the end of epoch 3: [0.0]%
Sparsity at the end of epoch 4: [90.0]%
Sparsity at the end of epoch 5: [90.0]%


---

## Iterative Pruning

Researchers have come up with a better way to prune than removing all the weights at once (as in One-Shot Pruning). The idea is to perform several iterations of pruning and fine-tuning, hence the name Iterative Pruning.

1. You first need to train a network
2. You then need to remove a part of the weights (depending on your criteria, needs, etc.)
3. You fine-tune the remaining weights to recover from the loss of parameters.
4. Back to step 2.
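
The loop formed by steps 2 to 4 might look like the sketch below. All names (`prune_smallest`, `fine_tune`, `iterative_prune`) are hypothetical, `fine_tune` is a placeholder that does nothing, and raising the sparsity in equal increments is an assumption made for simplicity.

```python
def prune_smallest(weights, sparsity):
    """Zero out the `sparsity` percent of weights with the smallest magnitude."""
    n_prune = round(len(weights) * sparsity / 100)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


def fine_tune(weights):
    """Placeholder: a real implementation would retrain the non-zero weights."""
    return weights


def iterative_prune(weights, target_sparsity, n_steps):
    """Alternate pruning and fine-tuning, raising sparsity in equal increments."""
    history = []
    for step in range(1, n_steps + 1):
        sparsity = target_sparsity * step / n_steps
        weights = prune_smallest(weights, sparsity)  # step 2
        weights = fine_tune(weights)                 # step 3
        n_zero = sum(w == 0.0 for w in weights)
        history.append(100 * n_zero / len(weights))
    return weights, history


# Toy weights (made-up values, for illustration only)
weights = [0.8, -0.05, 0.3, -0.9, 0.01, 0.4, -0.2, 0.6, -0.02, 0.1]
final, history = iterative_prune(weights, target_sparsity=90, n_steps=3)
print(history)
```

Already pruned weights have zero magnitude, so they are always among the smallest and stay pruned in later rounds.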

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning.

You only need to create the Callback with the `iterative` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.

The `iterative` schedule has an `n_steps` parameter, i.e. how many iterations of pruning/fine-tuning you want to perform. To modify its value, we can use the `partial` function like this:

```
iterative = partial(iterative, n_steps=5)
```
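
For readers unfamiliar with `functools.partial`, the toy example below shows the same pattern on a stand-in schedule. `step_schedule` is a hypothetical staircase function written for this illustration and is not claimed to match fasterai's `iterative`; only the `(start, end, pos)` calling convention is borrowed from the schedule calls used later in this tutorial.

```python
import math
from functools import partial


def step_schedule(start, end, pos, n_steps=3):
    """Hypothetical staircase: sparsity rises in `n_steps` equal jumps.

    `pos` is the normalized position in the pruning phase, in [0, 1].
    """
    completed = min(n_steps, math.ceil(pos * n_steps))
    return start + (end - start) * completed / n_steps


# partial() returns a new callable with n_steps fixed to 5;
# the remaining arguments are passed as usual.
five_steps = partial(step_schedule, n_steps=5)

print(step_schedule(0, 90, 0.5))  # default of 3 steps
print(five_steps(0, 90, 0.5))     # same position, 5 steps
```

The advantage of `partial` is that the resulting function keeps the plain `(start, end, pos)` interface, so code that expects a schedule function does not need to know about `n_steps`.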

In [None]:
sp_cb=SparsifyCallback(sparsity=90, granularity='weight', context='local', criteria=large_final, schedule=iterative)

Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before

In [None]:
learn.fit(10, cbs=sp_cb)

---

## Gradual Pruning

Here is, for example, how to apply the [Automated Gradual Pruning](https://arxiv.org/pdf/1710.01878.pdf) schedule.
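
The schedule proposed in that paper raises the sparsity along a cubic curve: it prunes quickly at first, when many weights are redundant, and more slowly as the target is approached. Below is a minimal sketch of the formula, written with a normalized position `pos` in [0, 1] instead of the paper's step counter. `agp_sparsity` is a hypothetical name and not necessarily how fasterai implements `agp`.

```python
def agp_sparsity(start, end, pos):
    """Cubic gradual-pruning curve.

    sparsity = end + (start - end) * (1 - pos)**3
    with pos = 0 at the first pruning step and pos = 1 at the last one.
    """
    return end + (start - end) * (1 - pos) ** 3


# Sparsity (in %) at a few points of the pruning phase, from 0% to 90%
for pos in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(pos, agp_sparsity(0, 90, pos))
```

Halfway through the pruning phase the curve has already reached 78.75% of sparsity, which leaves the second half of training for the last, more delicate, percentage points.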

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
sp_cb=SparsifyCallback(sparsity=90, granularity='weight', context='local', criteria=large_final, schedule=agp)

Let's start pruning after 3 epochs and train our model for 6 epochs to have the same total amount of training as before

In [None]:
learn.fit(10, cbs=sp_cb)

Even though they are often considered different pruning methods, those 3 schedules can be captured by the same Callback. Here is how the sparsity in the network evolves for each of them.

Let's take an example. Say that we want to train our network for 3 epochs without pruning and then for 7 epochs with pruning.

In [None]:
#| include: false
train = np.zeros(300)
prune = np.linspace(0,1, 700) 

Then this is what our different pruning schedules will look like:

In [None]:
#| echo: false
fig, ax = plt.subplots(1, 1, figsize=(8,5), dpi=100)
fig.patch.set_alpha(0.)
ax.patch.set_alpha(0.)
plt.plot(np.concatenate([train, sched_iterative(0,90, prune)]), label='Iterative', linestyle='-.', c='#89d6c9')
plt.plot(np.concatenate([train, [sched_oneshot(0,90, p) for p in prune]]), label='One-Shot', linestyle=':', c='#89d6c9')
plt.plot(np.concatenate([train, sched_agp(0,90, prune)]), label='Gradual', c='#89d6c9')
ax.spines['bottom'].set_color('#808080')
ax.spines['top'].set_color('#808080') 
ax.spines['right'].set_color('#808080')
ax.spines['left'].set_color('#808080')
ax.tick_params(axis='x', colors='#808080')
ax.tick_params(axis='y', colors='#808080')
ax.yaxis.label.set_color('#808080')
ax.xaxis.label.set_color('#808080')
plt.legend(framealpha=0.3);

**You can also come up with your own pruning schedule!**
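
As a starting point, a custom schedule can be as small as one function. The sketch below assumes the `(start, end, pos)` signature seen in the plotting calls above; `sched_quadratic` is a hypothetical example, and how such a function is handed to `SparsifyCallback` depends on the installed fasterai version, so that step is left out here.

```python
def sched_quadratic(start, end, pos):
    """Hypothetical custom schedule: prune slowly at first, faster at the end.

    `pos` is the normalized position in the pruning phase, in [0, 1].
    """
    return start + (end - start) * pos ** 2


# Sparsity (in %) at a few points of the pruning phase, from 0% to 90%
for pos in (0.0, 0.5, 1.0):
    print(pos, sched_quadratic(0, 90, pos))
```

This curve is the opposite of the gradual schedule in spirit: most of the pruning happens late, which may or may not suit a given model and is worth checking against the baseline.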