In [None]:
from fastai.vision.all import *



In [None]:
#| include: false
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.sparse.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

In [None]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

The most important part of our `Callback` happens in `before_batch`. There, we first compute the target sparsity of our network for the current training step according to the schedule, and then prune the parameters accordingly.
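Here is a minimal pure-Python sketch of the two operations that `before_batch` performs. It is not fasterai's implementation. The linear schedule, the hypothetical helpers `before_batch_sketch` and `prune_by_magnitude`, and the representation of a layer as a flat list of floats are all illustrative assumptions. Training progress is assumed to be a fraction in [0, 1]; fastai exposes it to callbacks as `pct_train`.

```python
def sched_lin(start, end, pos):
    # Linear interpolation from `start` (pos=0) to `end` (pos=1)
    return start + pos * (end - start)


def prune_by_magnitude(weights, sparsity):
    """Return a 0/1 mask that zeroes the `sparsity`% smallest-magnitude weights."""
    n_prune = int(round(len(weights) * sparsity / 100))
    # Indices sorted by |w|, smallest first; already-zeroed weights sort first,
    # so a non-decreasing schedule keeps previously pruned weights pruned
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


def before_batch_sketch(layers, pct_train, target, schedule=sched_lin):
    """Hypothetical stand-in for the callback's `before_batch`."""
    # 1) Current sparsity according to the schedule
    current = schedule(0, target, pct_train)
    # 2) Prune every layer (updating the dict) to that sparsity
    for name, weights in layers.items():
        mask = prune_by_magnitude(weights, current)
        layers[name] = [w if m else 0.0 for w, m in zip(weights, mask)]
    return current


layers = {"conv": [0.9, -0.1, 0.4, -0.7, 0.05, 0.3, -0.2, 0.8]}
current = before_batch_sketch(layers, pct_train=0.5, target=50)
print(current)         # 25.0
print(layers["conv"])  # [0.9, 0.0, 0.4, -0.7, 0.0, 0.3, -0.2, 0.8]
```

Halfway through training, the linear schedule asks for 25% sparsity. The two smallest-magnitude weights out of eight are zeroed.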

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
learn.fit_one_cycle(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.704975,2.227935,0.759134,00:03
1,0.40751,0.294927,0.885656,00:03
2,0.220607,0.27574,0.897158,00:03
3,0.128131,0.256966,0.906631,00:03
4,0.073029,0.234447,0.913396,00:03


Let's now try adding some sparsity to our model.

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

Compared to the `Sparsifier`, the `SparsifyCallback` requires one new argument: the pruning schedule to follow during training, so that the parameters can be pruned progressively as training advances.

You can use any scheduling function already [available](https://docs.fast.ai/callback.schedule.html#Annealing) in fastai or come up with your own! For more information about the pruning schedules, take a look at the [Schedules section](https://nathanhubens.github.io/fasterai/schedules.html).
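A schedule can be seen as a function that maps training progress `pos` in [0, 1] to a sparsity between a start and an end value. fastai's annealing functions use the same `(start, end, pos)` convention. The sketch below writes two such functions by hand. The first is a cosine schedule using the formula of fastai's `sched_cos`. The second is a custom example: the cubic schedule from Zhu & Gupta's automated gradual pruning. Both are illustrations only. How a plain function must be wrapped before being passed as `schedule=` may differ between fasterai versions, so check the Schedules section for the exact interface.

```python
import math


def sched_cos(start, end, pos):
    # Cosine annealing from `start` (pos=0) to `end` (pos=1), fastai's `sched_cos` formula
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def sched_agp(start, end, pos):
    # Cubic schedule from Zhu & Gupta's automated gradual pruning:
    # prunes quickly at the beginning, then slows down as it approaches `end`
    return end + (start - end) * (1 - pos) ** 3


n_epochs, target = 5, 50
# Sparsity reached at the end of each epoch (pos = fraction of training completed)
cos_values = [round(sched_cos(0, target, (e + 1) / n_epochs), 2) for e in range(n_epochs)]
agp_values = [round(sched_agp(0, target, (e + 1) / n_epochs), 2) for e in range(n_epochs)]
print(cos_values)  # [4.77, 17.27, 32.73, 45.23, 50.0]
print(agp_values)  # [24.4, 39.2, 46.8, 49.6, 50.0]
```

The cosine values match the per-layer sparsities logged at the end of each epoch for the `schedule=cos` run later in this notebook. The cubic schedule reaches most of its target in the first epochs.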

In [None]:
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)
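The callback above combines `granularity='weight'`, `context='local'` and `criteria=large_final`. As we understand these options, `large_final` keeps the weights with the largest magnitude. A local context prunes each layer to the requested sparsity independently, which is consistent with every Conv2d layer reporting 50.00% in the log below. The alternative `'global'` context would instead rank weights across all layers at once. The toy sketch below contrasts the two on two layers with very different weight scales. It uses plain Python with hypothetical helpers and is not fasterai code.

```python
def threshold_for(values, sparsity):
    """Magnitude at or below which `sparsity`% of `values` fall (toy: assumes no ties)."""
    mags = sorted(abs(v) for v in values)
    n_prune = int(round(len(mags) * sparsity / 100))
    return mags[n_prune - 1] if n_prune > 0 else -1.0


def prune(layers, sparsity, context="local"):
    """Keep the largest-magnitude weights; threshold per layer (local) or shared (global)."""
    if context == "global":
        shared = threshold_for([w for ws in layers.values() for w in ws], sparsity)
        thresholds = {name: shared for name in layers}
    else:
        thresholds = {name: threshold_for(ws, sparsity) for name, ws in layers.items()}
    return {name: [w if abs(w) > thresholds[name] else 0.0 for w in ws]
            for name, ws in layers.items()}


def layer_sparsity(ws):
    # Percentage of exactly-zero weights in one layer
    return 100 * sum(w == 0.0 for w in ws) / len(ws)


# Two toy layers with very different weight scales
layers = {"a": [0.1, -0.2, 0.3, -0.4], "b": [1.0, -2.0, 3.0, -4.0]}
for context in ("local", "global"):
    pruned = prune(layers, 50, context)
    print(context, {name: layer_sparsity(ws) for name, ws in pruned.items()})
# local {'a': 50.0, 'b': 50.0}
# global {'a': 100.0, 'b': 0.0}
```

Both variants remove half of all weights overall. Only the local one guarantees that each layer ends up at exactly 50%.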

In [None]:
learn.fit_one_cycle(5, cbs=sp_cb)

Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.69716,1.022295,0.79161,00:05
1,0.375767,0.320433,0.855886,00:05
2,0.228717,0.259796,0.893775,00:05
3,0.139725,0.212259,0.913396,00:05
4,0.08305,0.20872,0.926928,00:05


Sparsity at the end of epoch 0: [1.96]%
Sparsity at the end of epoch 1: [20.07]%
Sparsity at the end of epoch 2: [45.86]%
Sparsity at the end of epoch 3: [49.74]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 2: 50.00%
Sparsity in Conv2d 8: 50.00%
Sparsity in Conv2d 11: 50.00%
Sparsity in Conv2d 14: 50.00%
Sparsity in Conv2d 17: 50.00%
Sparsity in Conv2d 21: 50.00%
Sparsity in Conv2d 24: 50.00%
Sparsity in Conv2d 27: 50.00%
Sparsity in Conv2d 30: 50.00%
Sparsity in Conv2d 33: 50.00%
Sparsity in Conv2d 37: 50.00%
Sparsity in Conv2d 40: 50.00%
Sparsity in Conv2d 43: 50.00%
Sparsity in Conv2d 46: 50.00%
Sparsity in Conv2d 49: 50.00%
Sparsity in Conv2d 53: 50.00%
Sparsity in Conv2d 56: 50.00%
Sparsity in Conv2d 59: 50.00%
Sparsity in Conv2d 62: 50.00%
Sparsity in Conv2d 65: 50.00%


Perhaps surprisingly, our network, in which $50 \%$ of the convolution weights are zero, performs as well as our plain dense network. On this run, it even reaches a slightly higher final accuracy ($92.7 \%$ vs $91.3 \%$).

The `SparsifyCallback` also accepts a list of sparsities, with one entry per layer of type `layer_type` to be pruned. Below, we show how to prune only the intermediate layers of ResNet-18.
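The list must contain exactly one entry per pruned layer, presumably in the order in which the layers appear in the model. ResNet-18 has 20 `Conv2d` layers, which is why the log below reports 20 values. Instead of typing the list by hand, it can be built from an index range. `per_layer_sparsities` is a hypothetical helper, and the indices 6 to 13 simply reproduce the list used below.

```python
def per_layer_sparsities(n_layers, target, first, last):
    """Hypothetical helper: `target`% for layer indices in [first, last), 0% elsewhere."""
    return [target if first <= i < last else 0 for i in range(n_layers)]


# 20 Conv2d layers in ResNet-18; prune the 8 intermediate ones (indices 6 to 13)
sparsities = per_layer_sparsities(n_layers=20, target=50, first=6, last=14)
print(sparsities)
# [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]
```

To find the right length for another architecture, `sum(isinstance(m, nn.Conv2d) for m in learn.model.modules())` counts its convolution layers.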

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]

In [None]:
sp_cb = SparsifyCallback(sparsity=sparsities, granularity='weight', context='local', criteria=large_final, schedule=cos)

In [None]:
learn.fit_one_cycle(5, cbs=sp_cb)

Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.73165,0.5704,0.811908,00:07
1,0.396108,0.262083,0.895805,00:07
2,0.250992,0.210679,0.909337,00:07
3,0.132799,0.192091,0.925575,00:07
4,0.079732,0.159255,0.93843,00:07


Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity in Conv2d 2: 0.00%
Sparsity in Conv2d 8: 0.00%
Sparsity in Conv2d 11: 0.00%
Sparsity in Conv2d 14: 0.00%
Sparsity in Conv2d 17: 0.0

On top of that, the `SparsifyCallback` can also take many optional arguments:

- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their original values at each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether you want to reset the weights to their original values after training (pruning masks are still applied)
- `save_tickets`: whether to save intermediate winning tickets.
- `model`: pass a model, or a part of it, if you don't want to apply pruning to the whole trained model.
- `round_to`: if specified, the weights will be pruned to the closest multiple value of `round_to`.
- `layer_type`: specify the type of layer that you want to apply pruning to (defaults to `nn.Conv2d`)

For example, we pruned the convolution layers of our model here, but we could also prune the Linear layers, or even only the BatchNorm ones!
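The `lth`, `rewind_epoch` and `reset_end` options listed above all rely on the same operation. Once a pruning mask has been computed, the surviving weights are reset to a saved snapshot, while the pruned weights stay at zero. The snapshot is the original weights, or the weights at `rewind_epoch` when rewinding. Below is a minimal sketch of that reset on a single toy layer, assuming magnitude-based masks. `magnitude_mask`, `rewind` and the weight values are illustrative, not fasterai internals.

```python
def magnitude_mask(weights, sparsity):
    """0/1 mask that prunes the `sparsity`% smallest-magnitude weights."""
    n_prune = int(round(len(weights) * sparsity / 100))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


def rewind(saved, mask):
    """Reset surviving weights to their saved values; pruned weights stay at zero."""
    return [s if m else 0.0 for s, m in zip(saved, mask)]


saved = [0.5, -0.3, 0.2, -0.1]     # toy snapshot (original weights, or weights at `rewind_epoch`)
trained = [0.9, -0.05, 0.6, -0.4]  # toy weights after further training
mask = magnitude_mask(trained, 50)  # mask computed on the trained weights
print(mask)                 # [1, 0, 1, 0]
print(rewind(saved, mask))  # [0.5, 0.0, 0.2, 0.0]
```

With `lth`, this reset happens at each pruning step. With `reset_end`, it happens once after training, and the masks remain applied in both cases.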