In [None]:
#| default_exp sparse.sparsify_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.sparse.sparsifier import *
from fasterai.core.criteria import *
from fasterai.core.schedule import *
from typing import Callable, Optional, Union, List, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export
class SparsifyCallback(Callback):
    def __init__(self,
                 sparsity: Union[float, List[float]],    # Target sparsity level(s)
                 granularity: str,                       # Type of pruning granularity (e.g., 'weight', 'filter')
                 context: str,                           # Pruning context ('global' or 'local')
                 criteria: Criteria,                     # Criteria for determining weights to keep
                 schedule: Schedule,                     # Pruning schedule to use
                 lth: bool = False,                      # Whether to use Lottery Ticket Hypothesis approach
                 rewind_epoch: int = 0,                  # Epoch to rewind weights to for LTH
                 reset_end: bool = False,                # Whether to reset weights after pruning
                 save_tickets: bool = False,             # Whether to save pruned models as "winning tickets"
                 model: Optional[nn.Module] = None,      # Model to sparsify (if None, uses learn.model)
                 round_to: Optional[int] = None,         # Round pruning to multiple of this value
                 nm: bool = False,                       # Whether to use N:M structured sparsity
                 layer_type: Type[nn.Module] = nn.Conv2d # Layer type to apply pruning to
    ):
        "Callback to sparsify model during training according to a schedule"
        store_attr()
        # A single float is accepted as well; from here on `sparsity` is always a list.
        self.sparsity = listify(self.sparsity)

    def before_fit(self):
        "Check the rewind epoch, build the `Sparsifier` and initialise sparsity tracking"
        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
        # Explicit `None` test: container modules that define `__len__` are falsy when empty,
        # so `self.model or ...` could silently discard a model the user passed in.
        model = self.model if self.model is not None else self.learn.model
        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.nm, self.layer_type)
        # Sparsity in effect before the first batch. `before_batch` shifts `current_sparsity`
        # into `previous_sparsity` so intermediate tickets can be named after the sparsity
        # of the model being saved.
        # NOTE(review): assumes `schedule.current_sparsity` already holds the starting
        # sparsity at this point (it is read the same way in `after_fit`) - confirm.
        self.current_sparsity = self.previous_sparsity = listify(self.schedule.current_sparsity)

    def before_epoch(self):
        "Snapshot the weights when the rewind epoch is reached"
        # The snapshot is taken whether or not `lth` is set; `after_fit` also asks the
        # sparsifier to reset weights when `reset_end` is enabled.
        if self.epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {self.epoch}')
            self.sparsifier._save_weights()

    def before_batch(self):
        "Update sparsity level and potentially apply pruning"
        # Remember the sparsity that was in effect before asking the schedule for the new one.
        self.previous_sparsity = self.current_sparsity
        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train,3))
        if self.schedule.pruned and self.training:
            if self.lth and self.save_tickets:
                print('Saving Intermediate Ticket')
                # Saved before `sparsify_model` runs, hence named after the previous sparsity.
                self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
            self.sparsifier.sparsify_model(self.current_sparsity, self.round_to)

    def after_step(self):
        "Handle post-pruning steps"
        if self.lth and self.schedule.pruned:
            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
            self.sparsifier._reset_weights(self.learn.model)
        # Tell the schedule this step is done, then re-apply the masks after the
        # optimizer step (and after any weight reset above).
        self.schedule.after_pruned()
        self.sparsifier._apply_masks()

    def after_epoch(self):
        "Log sparsity after each epoch"
        # Two-decimal values for display only; `current_sparsity` itself is left untouched.
        rounded = [float("%0.2f" % sp) for sp in self.current_sparsity]
        print(f'Sparsity at the end of epoch {self.epoch}: {rounded}%')

    def after_fit(self):
        "Clean up after training"
        if self.save_tickets:
            print('Saving Final Ticket')
            # The final ticket is named after the last sparsity computed in `before_batch`.
            self.sparsifier.save_model(f'winning_ticket_{self.current_sparsity[0]:.2f}.pth', self.learn.model)
        print(f'Final Sparsity: {self.schedule.current_sparsity:}%')
        if self.reset_end: self.sparsifier._reset_weights()
        self.sparsifier._clean_buffers()
        # Reset the schedule's state once this fit is over.
        self.schedule.reset()
        self.sparsifier.print_sparsity()

In [None]:
# Render the generated API documentation for `SparsifyCallback`.
show_doc(SparsifyCallback)

---

[source](https://github.com/FasterAI-Labs/fasterai/blob/master/fasterai/sparse/sparsify_callback.py#L18){target="_blank" style="float:right; font-size:smaller"}

### SparsifyCallback

>      SparsifyCallback (sparsity:Union[float,List[float]], granularity:str,
>                        context:str, criteria:fasterai.core.criteria.Criteria,
>                        schedule:fasterai.core.schedule.Schedule,
>                        lth:bool=False, rewind_epoch:int=0,
>                        reset_end:bool=False, save_tickets:bool=False,
>                        model:Optional[torch.nn.modules.module.Module]=None,
>                        round_to:Optional[int]=None, nm:bool=False,
>                        layer_type:Type[torch.nn.modules.module.Module]=<class
>                        'torch.nn.modules.conv.Conv2d'>)

*Basic class handling tweaks of the training loop by changing a `Learner` in various events*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| sparsity | Union |  | Target sparsity level(s) |
| granularity | str |  | Type of pruning granularity (e.g., 'weight', 'filter') |
| context | str |  | Pruning context ('global' or 'local') |
| criteria | Criteria |  | Criteria for determining weights to keep |
| schedule | Schedule |  | Pruning schedule to use |
| lth | bool | False | Whether to use Lottery Ticket Hypothesis approach |
| rewind_epoch | int | 0 | Epoch to rewind weights to for LTH |
| reset_end | bool | False | Whether to reset weights after pruning |
| save_tickets | bool | False | Whether to save pruned models as "winning tickets" |
| model | Optional | None | Model to sparsify (if None, uses learn.model) |
| round_to | Optional | None | Round pruning to multiple of this value |
| nm | bool | False | Whether to use N:M structured sparsity |
| layer_type | Type | Conv2d | Layer type to apply pruning to |

The most important part of our `Callback` happens in `before_batch`. There, we first compute the sparsity of our network according to our schedule and then we remove the parameters accordingly.

The `SparsifyCallback` requires a new argument compared to the `Sparsifier`. Indeed, we need to know the pruning schedule that we should follow during training in order to prune the parameters accordingly.

You can use any scheduling function already [available](https://docs.fast.ai/callback.schedule.html#Annealing) in fastai or come up with your own! For more information about the pruning schedules, take a look at the [Schedules section](https://nathanhubens.github.io/fasterai/schedules.html).

On top of that, the `SparsifyCallback` can also take many optional arguments:

- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their original value at each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether you want to reset the weights to their original values after training (pruning masks are still applied)
- `save_tickets`: whether to save intermediate winning tickets.
- `model`: pass a model or a part of the model if you don't want to apply pruning on the whole model trained.
- `round_to`: if specified, the weights will be pruned to the closest multiple value of `round_to`.
- `layer_type`: specify the type of layer that you want to apply pruning to (defaults to `nn.Conv2d`)