In [None]:
#| default_exp prune.prune_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.prune.pruner import *
from fasterai.core.criteria import *
from fasterai.core.schedule import *

import torch
import torch.nn as nn
import torch.nn.functional as F

## Overview

The `PruneCallback` integrates structured pruning into the fastai training loop. Unlike sparsification (which zeros weights), pruning physically removes network structures (filters, channels) to reduce model size and computation.

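**Key Differences from SparsifyCallback:**
- Removes structures entirely instead of zeroing individual weights, so model size and computation actually shrink
- Uses the `torch-pruning` library to handle dependencies between layers (removing a filter also removes the matching input channels of the layers that consume it)
- Supports various pruning criteria and schedules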

In [None]:
#| export
class PruneCallback(Callback):
    "Structurally prune `learn.model` during training with a `Pruner`, following `schedule` until `pruning_ratio` is reached"
    def __init__(self, pruning_ratio, schedule, context, criteria, *args, **kwargs):
        # Saves pruning_ratio, schedule, context and criteria as attributes
        store_attr()
        self.sparsity_levels = []  # per-step target ratios, filled in by `before_fit`
        # Forwarded verbatim to `Pruner`. NOTE(review): extra positional args are inserted
        # right after the model, so depending on `Pruner`'s signature they may clash with
        # the keyword arguments passed in `before_fit` — prefer passing extras by keyword.
        self.extra_args = args
        self.extra_kwargs = kwargs

    def _build_pruning_schedule(self, sched_func):
        "Create a schedule function compatible with torch-pruning's Pruner"
        start_val, end_val = self.schedule.start_val, self.schedule.end_val
        def scheduler(pruning_ratio, steps, start=start_val, end=end_val):
            # Returns `steps + 1` targets, scaled by `pruning_ratio`. Entry `i` is the
            # ratio targeted after `i` pruning steps (torch-pruning's MetaPruner indexes
            # this list with its step counter after incrementing it — confirm if Pruner differs).
            if steps <= 0:
                # No training steps (e.g. zero epochs or an empty train loader): avoid a
                # division by zero and only report the final target.
                return [sched_func(start, end, 1.0) * pruning_ratio]
            return [sched_func(start, end, i / steps) * pruning_ratio for i in range(steps + 1)]
        return scheduler

    def before_fit(self) -> None:
        "Setup pruner before training"
        n_batches_per_epoch = len(self.learn.dls.train)
        # `before_step` prunes once per training batch, so this is the number of pruning steps
        total_training_steps = n_batches_per_epoch * self.learn.n_epoch
        # Values above 1 are percentages (e.g. 30 -> 0.3); values in (0, 1] are already fractions
        self.pruning_ratio = self.pruning_ratio/100 if self.pruning_ratio>1 else self.pruning_ratio
        
        # Validate pruning_ratio is in valid range
        if not (0 < self.pruning_ratio <= 1):
            raise ValueError(f"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}")

        # NOTE(review): stored on the callback but not forwarded to `Pruner` below — confirm
        # that Pruner does not need the real input shape.
        self.example_inputs, _ = self.learn.dls.one_batch()
        
        # Build schedule function for torch-pruning compatibility; the same per-step targets
        # are kept locally so `after_epoch` can report them.
        pruning_schedule = self._build_pruning_schedule(self.schedule.sched_func)
        self.sparsity_levels = pruning_schedule(self.pruning_ratio, total_training_steps)
        
        self.pruner = Pruner(
            self.learn.model,
            criteria=self.criteria,
            pruning_ratio=self.pruning_ratio, 
            context=self.context,
            iterative_steps=total_training_steps, 
            schedule=pruning_schedule,
            *self.extra_args, 
            **self.extra_kwargs
        )
        
    def before_step(self) -> None:
        "Apply pruning before optimizer step"
        if self.training: 
            self.pruner.prune_model()

    def after_epoch(self) -> None:
        "Log the target sparsity reached at the end of each epoch"
        completed_steps = (self.epoch + 1) * len(self.learn.dls.train)
        # `sparsity_levels[k]` is the target after `k` pruning steps (index 0 is the starting
        # point), so the level reached after `completed_steps` steps is at index
        # `completed_steps`; skip logging if that lies outside the schedule.
        if 0 < completed_steps < len(self.sparsity_levels):
            current_sparsity = self.sparsity_levels[completed_steps]
            print(f'Sparsity at the end of epoch {self.epoch}: {current_sparsity*100:.2f}%')

In [None]:
# Render the API documentation (signature and docstring) for PruneCallback.
show_doc(PruneCallback)

Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


---

[source](https://github.com/FasterAI-Labs/fasterai/blob/master/fasterai/prune/prune_callback.py#L18){target="_blank" style="float:right; font-size:smaller"}

### PruneCallback

```python

def PruneCallback(
    pruning_ratio, schedule, context, criteria, args:VAR_POSITIONAL, kwargs:VAR_KEYWORD
):


```

*Basic class handling tweaks of the training loop by changing a `Learner` in various events*

**Parameters:**

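- `pruning_ratio`: Target pruning ratio, given either as a fraction in (0, 1] or as a percentage in (1, 100]. Values greater than 1 are divided by 100, so `30` means 30% and `1` means 100%. Anything outside these ranges raises a `ValueError` when training starts.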
- `schedule`: When to prune (from `fasterai.core.schedule`). Controls how pruning progresses over training.
- `context`: `'local'` (per-layer pruning) or `'global'` (across entire model).
- `criteria`: How to select what to prune (from `fasterai.core.criteria`).

---

## Usage Example

```python
from fasterai.prune.prune_callback import PruneCallback
from fasterai.core.schedule import agp
from fasterai.core.criteria import large_final

# Prune 30% of parameters using automated gradual pruning schedule
cb = PruneCallback(
    pruning_ratio=30,        # Remove 30% of parameters
    schedule=agp,            # Gradual pruning (cubic decay)
    context='global',        # Prune globally across all layers
    criteria=large_final     # Keep weights with largest magnitude
)

learn.fit(10, cbs=[cb])
```

In [None]:
#| hide
from fastcore.test import *

# Constructing the callback only records its configuration; the raw
# percentage stays untouched until `before_fit` normalises it.
cb = PruneCallback(pruning_ratio=30, schedule=agp, context='global', criteria=large_final)
test_eq(cb.context, 'global')
test_eq(cb.pruning_ratio, 30)

In [None]:
#| hide
#| slow
# Full training with PruneCallback — verify parameter reduction
from torch.utils.data import TensorDataset
from fastai.data.core import DataLoaders

# Seed python/numpy/torch so weights, data and shuffling are reproducible across runs
set_seed(0)

_model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10)
)
_params_before = sum(p.numel() for p in _model.parameters())

_X = torch.randn(64, 3, 8, 8)
_y = torch.randint(0, 10, (64,))
_dls = DataLoaders.from_dsets(
    TensorDataset(_X[:48], _y[:48]),
    TensorDataset(_X[48:], _y[48:]),
    bs=16, device='cpu'
)

_cb = PruneCallback(pruning_ratio=30, schedule=one_shot, context='local', criteria=large_final)
_learn = Learner(_dls, _model, loss_func=nn.CrossEntropyLoss(), cbs=[_cb])
_learn.fit(3)

_params_after = sum(p.numel() for p in _model.parameters())
assert _params_after < _params_before, f"Expected params to decrease: {_params_before} → {_params_after}"
# The percentage is normalised to a fraction when training starts
assert _cb.pruning_ratio == 0.3, f"Expected normalised ratio 0.3, got {_cb.pruning_ratio}"
# The output layer is ignored by the pruner, so the pruned model still yields 10 logits
assert _model(_X[:2]).shape == (2, 10)

Ignoring output layer: 8
Total ignored layers: 1


epoch,train_loss,valid_loss,time
0,2.315615,2.329493,00:00
1,2.315022,2.330559,00:00
2,2.31449,2.331829,00:00


Sparsity at the end of epoch 0: 30.00%
Sparsity at the end of epoch 1: 30.00%
Sparsity at the end of epoch 2: 30.00%


---

## See Also

- [Pruner](pruner.html) - Core structured pruning class used by this callback
- [Schedules](../core/schedules.html) - Control pruning progression during training
- [Criteria](../core/criteria.html) - Importance measures for selecting filters to prune
- [SparsifyCallback](../sparse/sparsify_callback.html) - Unstructured pruning alternative