In [None]:
#| default_exp sparse.sparsify_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import annotations
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.sparse.sparsifier import *
from fasterai.core.criteria import *
from fasterai.core.schedule import *
from typing import Callable, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

## Overview

The `SparsifyCallback` integrates weight sparsification into the fastai training loop. Unlike pruning (which physically removes structures), sparsification zeros out weights (individually or in groups, depending on the granularity) while keeping the original network shape.

**Key Features:**
- Gradual sparsification according to a schedule
- Support for Lottery Ticket Hypothesis (LTH) training
- Multiple granularity levels (weight, vector, kernel, filter)
- Global or local sparsification context

In [None]:
#| export
class SparsifyCallback(Callback):
    "Gradually sparsify a model during training following a pruning `Schedule`, with optional Lottery Ticket rewinding"
    def __init__(self, 
                 sparsity: float | dict[str, float],        # Target sparsity (float) or per-layer dict
                 granularity: str,                           # Type of pruning granularity (e.g., 'weight', 'filter')
                 context: str,                               # Pruning context ('global' or 'local')
                 criteria: Criteria,                         # Criteria for determining weights to keep
                 schedule: Schedule,                         # Pruning schedule to use
                 lth: bool = False,                          # Whether to use Lottery Ticket Hypothesis approach
                 rewind_epoch: int = 0,                      # Epoch to rewind weights to for LTH
                 reset_end: bool = False,                    # Whether to reset weights after pruning
                 save_tickets: bool = False,                 # Whether to save pruned models as "winning tickets"
                 model: nn.Module | None = None,             # Model to sparsify (if None, uses learn.model)
                 round_to: int | None = None,                # Round pruning to multiple of this value
                 nm: bool = False,                           # Whether to use N:M structured sparsity
                 layer_type: Type[nn.Module] = nn.Conv2d     # Layer type to apply pruning to
    ):
        "Callback to sparsify model during training according to a schedule"
        store_attr()
        # Sparsity reached so far: a float, or a {layer: sparsity} dict when `sparsity` is a dict
        self.current_sparsity = 0.0

    def _sparsity_value(self) -> float:
        "Extract a single sparsity value for logging/saving (first value if dict, 0.0 for an empty dict)"
        if isinstance(self.current_sparsity, dict):
            # `next` with a default avoids a StopIteration leaking out of an empty dict
            return next(iter(self.current_sparsity.values()), 0.0)
        return self.current_sparsity

    def before_fit(self) -> None:
        "Setup sparsifier before training"
        # A per-layer dict is printed as-is (no trailing '%'), consistent with `after_fit`
        target = self.sparsity if isinstance(self.sparsity, dict) else f'{self.sparsity}%'
        print(f'Pruning of {self.granularity} until a sparsity of {target}')
        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
        # Explicit `None` check: container modules (e.g. `nn.Sequential`) define `__len__`,
        # so a truthiness test would silently swap an empty container for `learn.model`.
        model = self.model if self.model is not None else self.learn.model
        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.nm, self.layer_type)

    def before_epoch(self) -> None:
        "Snapshot the weights at `rewind_epoch`, so `_reset_weights` can later restore them (LTH rewinding or `reset_end`)"
        if self.epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {self.epoch}')
            self.sparsifier._save_weights()

    def before_batch(self) -> None:
        "Update sparsity level and potentially apply pruning"
        progress = self.schedule.progress(round(self.pct_train, 3))
        
        # Compute current sparsity: float * progress or {layer: sp * progress}
        if isinstance(self.sparsity, dict):
            self.current_sparsity = {k: v * progress for k, v in self.sparsity.items()}
        else:
            self.current_sparsity = self.sparsity * progress
        
        # Only (re)prune during training, and only when the schedule moved to a new step
        if self.schedule.changed and self.training:
            if self.lth and self.save_tickets:
                print('Saving Intermediate Ticket')
                self.sparsifier.save_model(f'winning_ticket_{self._sparsity_value():.2f}.pth', self.learn.model)
            self.sparsifier.sparsify_model(self.current_sparsity, self.round_to)

    def after_step(self) -> None:
        "Handle post-pruning steps: LTH rewinding, schedule bookkeeping and mask re-application"
        if self.lth and self.schedule.changed:
            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
            self.sparsifier._reset_weights(self.learn.model)
        self.schedule.after_step()
        # Re-apply the masks after the optimizer step so pruned weights stay at zero
        self.sparsifier._apply_masks()

    def after_epoch(self) -> None:
        "Log sparsity after each epoch"
        if isinstance(self.current_sparsity, dict):
            values = list(self.current_sparsity.values())
            # Guard against an empty per-layer dict (would otherwise divide by zero)
            avg_sparsity = sum(values) / len(values) if values else 0.0
            print(f'Sparsity at the end of epoch {self.epoch}: avg={avg_sparsity:.2f}%')
        else:
            print(f'Sparsity at the end of epoch {self.epoch}: {self.current_sparsity:.2f}%')

    def after_fit(self) -> None:
        "Clean up after training: optionally save the final ticket / reset weights, then report sparsity"
        if self.save_tickets:
            print('Saving Final Ticket')
            self.sparsifier.save_model(f'winning_ticket_{self._sparsity_value():.2f}.pth', self.learn.model)
        
        if isinstance(self.current_sparsity, dict):
            print(f'Final Sparsity: {self.current_sparsity}')
        else:
            print(f'Final Sparsity: {self.current_sparsity:.2f}%')
        
        # NOTE(review): called without a model here, unlike `after_step` which passes `self.learn.model` — confirm intended
        if self.reset_end: self.sparsifier._reset_weights()
        self.sparsifier._clean_buffers()
        # Reset the schedule so the same callback instance can be reused for another fit
        self.schedule.reset()
        self.sparsifier.print_sparsity()

In [None]:
# Render the API reference (signature and per-parameter comments) for `SparsifyCallback`
show_doc(SparsifyCallback)

Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/sparse/sparsify_callback.py#L20){target="_blank" style="float:right; font-size:smaller"}

### SparsifyCallback

```python

def SparsifyCallback(
    sparsity:float | dict[str, float], # Target sparsity (float) or per-layer dict
    granularity:str, # Type of pruning granularity (e.g., 'weight', 'filter')
    context:str, # Pruning context ('global' or 'local')
    criteria:Criteria, # Criteria for determining weights to keep
    schedule:Schedule, # Pruning schedule to use
    lth:bool=False, # Whether to use Lottery Ticket Hypothesis approach
    rewind_epoch:int=0, # Epoch to rewind weights to for LTH
    reset_end:bool=False, # Whether to reset weights after pruning
    save_tickets:bool=False, # Whether to save pruned models as "winning tickets"
    model:nn.Module | None=None, # Model to sparsify (if None, uses learn.model)
    round_to:int | None=None, # Round pruning to multiple of this value
    nm:bool=False, # Whether to use N:M structured sparsity
    layer_type:Type[nn.Module]=<class 'torch.nn.modules.conv.Conv2d'>, # Layer type to apply pruning to
):


```

*Basic class handling tweaks of the training loop by changing a `Learner` in various events*

The most important part of our `Callback` happens in `before_batch`. There, we first compute the target sparsity of our network according to our schedule. Then, whenever the schedule moves to a new step, we zero out the corresponding parameters.

The `SparsifyCallback` requires a new argument compared to the `Sparsifier`. Indeed, we need to know the pruning schedule that we should follow during training in order to prune the parameters accordingly.

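You can use one of the `Schedule` objects already provided by fasterai (e.g. `one_shot`, `lin`, `cos`, `one_cycle`) or come up with your own! Note that the callback expects a `Schedule` object, since it relies on its `progress`, `start_pct`, `changed`, `after_step` and `reset` members. A bare scheduling function such as fastai's [annealing functions](https://docs.fast.ai/callback.schedule.html#Annealing) will not work on its own. For more information about the pruning schedules, take a look at the [Schedules section](https://nathanhubens.github.io/fasterai/schedules.html).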

On top of that, the `SparsifyCallback` can also take many optional arguments: 

- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their `rewind_epoch` values after each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether to reset the weights to their saved (`rewind_epoch`) values after training (pruning masks are still applied)
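- `save_tickets`: whether to save winning tickets. Intermediate tickets are saved at each pruning step only when `lth=True`; a final ticket is saved at the end of training.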
- `model`: pass a model or a part of the model if you don't want to apply pruning on the whole model trained.
- `round_to`: if specified, pruning will be rounded to the closest multiple of `round_to`.
- `nm`: whether to use N:M structured sparsity.
- `layer_type`: specify the type of layer that you want to apply pruning to (defaults to `nn.Conv2d`).

---

## Usage Example

```python
from fasterai.sparse.sparsify_callback import SparsifyCallback
from fasterai.core.schedule import cos, one_cycle
from fasterai.core.criteria import large_final

# Gradually sparsify to 50% using cosine schedule
cb = SparsifyCallback(
    sparsity=50,
    granularity='weight',
    context='global',
    criteria=large_final,
    schedule=cos
)

learn.fit(10, cbs=[cb])
```

### Per-Layer Sparsity with Dict

```python
# Different sparsity targets for different layers
cb = SparsifyCallback(
    sparsity={'conv1': 30, 'layer1': 50, 'layer2': 70},
    granularity='weight',
    context='local',
    criteria=large_final,
    schedule=cos
)
```

### With Lottery Ticket Hypothesis

```python
# Train with LTH - rewind weights to epoch 2 values after each pruning step
cb = SparsifyCallback(
    sparsity=90,
    granularity='weight',
    context='global', 
    criteria=large_final,
    schedule=one_cycle,
    lth=True,
    rewind_epoch=2
)

learn.fit(20, cbs=[cb])
```

In [None]:
#| hide
from fastcore.test import *

# Construction with valid params (objects, not strings): attributes are stored verbatim
cb = SparsifyCallback(sparsity=50, granularity='weight', context='local',
                      criteria=large_final, schedule=one_shot)
for _attr, _expected in [('sparsity', 50), ('granularity', 'weight'),
                         ('context', 'local'), ('current_sparsity', 0.0)]:
    test_eq(getattr(cb, _attr), _expected)

# Dict-based sparsity: per-layer targets are kept as a dict
cb_dict = SparsifyCallback(sparsity={'conv1': 30, 'conv2': 60}, granularity='weight',
                           context='local', criteria=large_final, schedule=lin)
assert isinstance(cb_dict.sparsity, dict)
for _layer, _target in {'conv1': 30, 'conv2': 60}.items():
    test_eq(cb_dict.sparsity[_layer], _target)

# `_sparsity_value` returns a float unchanged...
cb.current_sparsity = 42.0
test_eq(cb._sparsity_value(), 42.0)

# ...and the first value of a per-layer dict
cb_dict.current_sparsity = {'a': 10.0, 'b': 20.0}
test_eq(cb_dict._sparsity_value(), 10.0)

In [None]:
#| hide
#| slow
# Full training loop with SparsifyCallback — verify sparsity is applied
from torch.utils.data import TensorDataset
from fastai.data.core import DataLoaders

torch.manual_seed(0)  # deterministic synthetic data, weight init and training

_model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10)
)
_X = torch.randn(64, 3, 8, 8)
_y = torch.randint(0, 10, (64,))
_dls = DataLoaders.from_dsets(
    TensorDataset(_X[:48], _y[:48]),
    TensorDataset(_X[48:], _y[48:]),
    bs=16, device='cpu'
)

_cb = SparsifyCallback(sparsity=50, granularity='weight', context='local',
                       criteria=large_final, schedule=one_shot)
_learn = Learner(_dls, _model, loss_func=nn.CrossEntropyLoss(), cbs=[_cb])
_learn.fit(3)

# Verify sparsification was applied to conv layers (the default `layer_type`)
_n_conv = 0
for m in _model.modules():
    if isinstance(m, nn.Conv2d):
        _sp = (m.weight == 0).float().mean().item() * 100
        test_close(_sp, 50.0, eps=10.0)
        _n_conv += 1
test_eq(_n_conv, 1)  # make sure the check above did not pass vacuously

# Layers outside `layer_type` must stay dense
for m in _model.modules():
    if isinstance(m, nn.Linear):
        test_eq((m.weight == 0).sum().item(), 0)

Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,time
0,2.31583,2.356278,00:00
1,2.313395,2.356196,00:00
2,2.310454,2.35688,00:00


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 50.00%
Sparsity at the end of epoch 2: 50.00%
Final Sparsity: 50.00%

Sparsity Report:
--------------------------------------------------------------------------------
Layer                          Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
0                              Conv2d          432        216           50.00%
--------------------------------------------------------------------------------
Overall                        all             432        216           50.00%


---

## See Also

- [Sparsifier](sparsifier.html) - Core sparsification class used by this callback
- [Schedules](../core/schedules.html) - Control sparsification progression (one_shot, agp, etc.)
- [Criteria](../core/criteria.html) - Importance measures (large_final, movement, etc.)
- [Lottery Ticket Tutorial](../tutorials/sparse/lottery_ticket.html) - Finding winning tickets with sparsification