
In [None]:
#| include: false
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.sparse.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

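# `from_name_func` passes each file's name: in Oxford-IIIT Pets, cat images have capitalized breed names
def label_func(f): return f[0].isupper()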

In [None]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

The most important part of our `SparsifyCallback` happens in `before_batch`. There, it first computes the current sparsity of the network according to the pruning schedule, then removes (zeroes out) parameters until that sparsity is reached.
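To make this concrete, here is a toy, framework-free sketch of what one such step could look like. It is not fasterai's implementation: the function names are hypothetical, a plain Python list stands in for a layer's weight tensor, a linear schedule stands in for the real ones, and smallest-magnitude removal stands in for the `large_final` criterion. It also recomputes the mask from scratch at each call instead of maintaining it across batches.

```python
def sched_lin(start, end, pos):
    # Linear ramp from `start` (pos=0) to `end` (pos=1), same convention as fastai's sched_lin
    return start + pos * (end - start)

def prune_step(weights, target_sparsity, pos, schedule=sched_lin):
    """Hypothetical sketch of one `before_batch` step.

    1. Ask the schedule for the sparsity (%) to reach at training progress `pos`.
    2. Zero out that fraction of weights with the smallest magnitude,
       i.e. keep the weights with the largest current magnitude.
    """
    current = schedule(0, target_sparsity, pos)
    n_prune = round(len(weights) * current / 100)
    # Indices sorted by increasing magnitude: the first `n_prune` are pruned
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return current, [0.0 if i in pruned else w for i, w in enumerate(weights)]

weights = [0.1, -0.5, 0.3, -0.05, 0.8, -0.2]
for pos in (0.0, 0.5, 1.0):
    sparsity, new_weights = prune_step(weights, 50, pos)
    print(f"pos={pos:.1f}  target={sparsity:.1f}%  weights={new_weights}")
```

As training progresses (`pos` going from 0 to 1), the schedule raises the target and more small-magnitude weights are zeroed, until the final sparsity is reached.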

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
learn.fit_one_cycle(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.760291,0.818725,0.805142,00:02
1,0.401118,0.317854,0.88498,00:02
2,0.247798,0.263363,0.893775,00:02
3,0.133845,0.215846,0.927605,00:02
4,0.069763,0.152332,0.94452,00:02


Let's now try adding some sparsity to our model.

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

The `SparsifyCallback` requires one more argument than the `Sparsifier`: since pruning now happens progressively during training, it needs to know the pruning schedule to follow in order to prune the parameters accordingly.

You can use any scheduling function already [available](https://docs.fast.ai/callback.schedule.html#Annealing) in fastai or come up with your own! For more information about the pruning schedules, take a look at the [Schedules section](https://nathanhubens.github.io/fasterai/schedules.html).
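As an illustration of writing your own, here is a sketch of a cubic schedule in the style of Automated Gradual Pruning. It assumes a schedule is a function `f(start, end, pos)` returning the sparsity at training progress `pos` in [0, 1], the same convention as fastai's annealing functions; the name `sched_agp` is hypothetical, and you should check the Schedules section for the exact form `SparsifyCallback` expects before passing it as `schedule=`.

```python
def sched_agp(start, end, pos):
    """Hypothetical cubic schedule (Automated Gradual Pruning style):
    sparsity rises quickly early in training, then flattens out near `end`."""
    return end + (start - end) * (1 - pos) ** 3

# Sparsity (%) this schedule would target at a few points of training
for pos in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"pos={pos:.2f} -> {sched_agp(0, 50, pos):.2f}%")
```

Front-loading the pruning like this leaves the later part of training for the network to recover from the removed weights.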

In [None]:
sp_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local', criteria=large_final, schedule=one_cycle)

In [None]:
learn.fit_one_cycle(5, cbs=sp_cb)

Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.696515,0.492782,0.834912,00:02
1,0.410025,0.238266,0.897835,00:02
2,0.236258,0.178887,0.928281,00:02
3,0.135842,0.174733,0.930988,00:02
4,0.068706,0.162616,0.941137,00:02


Sparsity at the end of epoch 0: [1.96]%
Sparsity at the end of epoch 1: [20.07]%
Sparsity at the end of epoch 2: [45.86]%
Sparsity at the end of epoch 3: [49.74]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
Layer 2              Conv2d          9,408      4,704         50.00%
Layer 8              Conv2d          36,864     18,432        50.00%
Layer 11             Conv2d          36,864     18,432        50.00%
Layer 14             Conv2d          36,864     18,432        50.00%
Layer 17             Conv2d          36,864     18,432        50.00%
Layer 21             Conv2d          73,728     36,864        50.00%
Layer 24             Conv2d          147,456    73,727        50.00%
Layer 27             Conv2d          8,1

Surprisingly, our network, with $50 \%$ of its convolution weights set to zero, performs almost as well as our plain, dense network: $94.1 \%$ accuracy after 5 epochs, against $94.5 \%$.
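The per-epoch sparsities logged during this run trace an S-shaped curve: slow at first, steep in the middle, then flat. As a sketch, a logistic ramp evaluated at the end of each epoch reproduces them. Assumptions: the name `sched_onecycle` and the values `alpha=14`, `beta=6` were chosen to match the log above, not taken from fasterai's source.

```python
import math

def sched_onecycle(start, end, pos, alpha=14, beta=6):
    # Logistic (S-shaped) ramp: slow start, steep middle, plateau near `end`.
    # The numerator makes the curve reach exactly `end` at pos=1.
    return start + (end - start) * (1 + math.exp(-alpha + beta)) / (1 + math.exp(-alpha * pos + beta))

n_epochs = 5
history = [round(sched_onecycle(0, 50, (epoch + 1) / n_epochs), 2) for epoch in range(n_epochs)]
for epoch, s in enumerate(history):
    print(f"Sparsity at the end of epoch {epoch}: [{s}]%")
```

Most of the pruning thus happens in epochs 1 and 2, and the last epochs are spent at (almost) the final sparsity.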

The `SparsifyCallback` also accepts a list of sparsities, corresponding to each layer of `layer_type` to be pruned. Below, we show how to prune only the intermediate layers of ResNet-18.
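The list needs one value per `nn.Conv2d` layer. ResNet-18 has 20 of them, counting the three 1×1 downsampling convolutions (on your own model, `sum(isinstance(m, nn.Conv2d) for m in learn.model.modules())` gives the count). Rather than typing the list by hand, it can be generated; here is a small sketch with a hypothetical helper:

```python
def per_layer_sparsity(n_layers, target, start, stop):
    """Hypothetical helper: `target` % sparsity for layers with index in [start, stop), 0 elsewhere."""
    return [target if start <= i < stop else 0 for i in range(n_layers)]

# 20 Conv2d layers in ResNet-18; prune only layers 6 to 13 to 50 %
sparsities = per_layer_sparsity(20, 50, 6, 14)
print(sparsities)
```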

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]

In [None]:
sp_cb = SparsifyCallback(sparsity=sparsities, granularity='weight', context='local', criteria=large_final, schedule=cos)

In [None]:
learn.fit_one_cycle(5, cbs=sp_cb)

Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.692292,0.521045,0.857916,00:02
1,0.42707,0.239684,0.895805,00:02
2,0.276716,0.387131,0.866035,00:02
3,0.157218,0.189208,0.923545,00:02
4,0.09077,0.156103,0.941813,00:02


Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                Type            Para
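With `schedule=cos`, each layer follows its own target: layers with a target of 0 stay dense, while the others follow a cosine ramp. Assuming fasterai's `cos` schedule uses fastai's cosine annealing formula (an assumption, consistent with the numbers logged above), the per-epoch values can be reproduced by evaluating it at the end of each epoch:

```python
import math

def sched_cos(start, end, pos):
    # Cosine annealing from `start` (pos=0) to `end` (pos=1), as in fastai's sched_cos
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

sparsities = [0] * 6 + [50] * 8 + [0] * 6  # same per-layer targets as above
n_epochs = 5
history = []
for epoch in range(n_epochs):
    pos = (epoch + 1) / n_epochs  # training progress at the end of `epoch`
    current = [round(sched_cos(0, s, pos), 2) for s in sparsities]
    history.append(current)
    print(f"Sparsity at the end of epoch {epoch}: {current}%")
```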

On top of that, the `SparsifyCallback` can also take many optional arguments:

- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their original values at each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether to reset the weights to their original values after training (pruning masks are still applied)
- `save_tickets`: whether to save intermediate winning tickets
- `model`: pass a model, or a part of the model, if you don't want to apply pruning to the whole trained model
- `round_to`: if specified, the weights will be pruned to the closest multiple of `round_to`
- `layer_type`: the type of layer to apply pruning to (defaults to `nn.Conv2d`)
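To make the `lth` / `rewind_epoch` description more concrete, here is a framework-free sketch of a single Lottery Ticket reset: the mask is computed from the trained weights (largest magnitudes survive), then the surviving weights are reset to the values saved at the reference epoch while pruned ones stay at zero. The helper names are hypothetical and plain lists stand in for tensors; fasterai's callback performs this on the model's parameters.

```python
def magnitude_mask(weights, sparsity):
    # 1 keeps a weight, 0 prunes it: the `sparsity` % smallest-magnitude weights are pruned
    n_prune = round(len(weights) * sparsity / 100)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]

def rewind(saved_weights, mask):
    # Lottery-ticket reset: survivors go back to their saved values, pruned weights stay at 0
    return [w if m else 0.0 for w, m in zip(saved_weights, mask)]

saved = [0.3, -0.2, 0.05, 0.4]       # weights saved at `rewind_epoch` (0 = initialization)
trained = [0.9, -0.01, 0.6, -0.02]   # weights after some training
mask = magnitude_mask(trained, 50)   # -> [1, 0, 1, 0]
ticket = rewind(saved, mask)         # -> [0.3, 0.0, 0.05, 0.0]
print(mask, ticket)
```

Note that the mask is chosen from the *trained* magnitudes, but the values kept are the *saved* ones: that combination is what makes the result a "ticket" rather than a simply pruned network.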

For example, we pruned the convolution layers of our model above, but we could just as well prune the `nn.Linear` layers, or even only the BatchNorm ones!