In [None]:
#| default_exp prune.prune_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.prune.pruner import *
from fasterai.core.criteria import *
from fasterai.core.schedule import *

import torch
import torch.nn as nn
import torch.nn.functional as F

## Overview

The `PruneCallback` integrates structured pruning into the fastai training loop. Unlike sparsification (which zeros weights), pruning physically removes network structures (filters, channels) to reduce model size and computation.

**Key Differences from SparsifyCallback:**
- Removes structures entirely (not just zeros)
- Uses `torch-pruning` library for dependency handling
- Supports various pruning criteria and schedules

In [None]:
#| export
class PruneCallback(Callback):
    def __init__(self, pruning_ratio, schedule, context, criteria, *args, **kwargs):
        store_attr()
        self.sparsity_levels = []
        self.extra_args = args
        self.extra_kwargs = kwargs

    def before_fit(self) -> None:
        "Setup pruner before training"
        n_batches_per_epoch = len(self.learn.dls.train)
        total_training_steps = n_batches_per_epoch * self.learn.n_epoch
        self.pruning_ratio = self.pruning_ratio/100 if self.pruning_ratio>1 else self.pruning_ratio
        
        # Validate pruning_ratio is in valid range
        if not (0 < self.pruning_ratio <= 1):
            raise ValueError(f"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}")

        self.example_inputs, _ = self.learn.dls.one_batch()
        self.sparsity_levels = self.schedule._scheduler(self.pruning_ratio, total_training_steps)
        
        self.pruner = Pruner(
        self.learn.model,
        criteria=self.criteria,
        pruning_ratio=self.pruning_ratio, 
        context=self.context,
        iterative_steps= total_training_steps, 
        schedule=self.schedule._scheduler,
        *self.extra_args, 
        **self.extra_kwargs
        )
        
    def before_step(self) -> None:
        "Apply pruning before optimizer step"
        if self.training: 
            self.pruner.prune_model()

    def after_epoch(self) -> None:
        "Log sparsity after each epoch"
        completed_steps = (self.epoch + 1) * len(self.learn.dls.train)
        # Bounds check for sparsity_levels access
        if completed_steps > 0 and completed_steps <= len(self.sparsity_levels):
            current_sparsity = self.sparsity_levels[completed_steps - 1]
            print(f'Sparsity at the end of epoch {self.epoch}: {current_sparsity*100:.2f}%')

In [None]:
show_doc(PruneCallback)

**Parameters:**

- `pruning_ratio`: Target ratio of parameters to remove (0-1 or 0-100). Values >1 are treated as percentages.
- `schedule`: When to prune (from `fasterai.core.schedule`). Controls how pruning progresses over training.
- `context`: `'local'` (per-layer pruning) or `'global'` (across entire model).
- `criteria`: How to select what to prune (from `fasterai.core.criteria`).

---

## Usage Example

```python
from fasterai.prune.prune_callback import PruneCallback
from fasterai.core.schedule import agp
from fasterai.core.criteria import large_final

# Prune 30% of parameters using automated gradual pruning schedule
cb = PruneCallback(
    pruning_ratio=30,        # Remove 30% of parameters
    schedule=agp,            # Gradual pruning (cubic decay)
    context='global',        # Prune globally across all layers
    criteria=large_final     # Keep weights with largest magnitude
)

learn.fit(10, cbs=[cb])
```

---

## See Also

- [Pruner](pruner.html) - Core structured pruning class used by this callback
- [Schedules](../core/schedules.html) - Control pruning progression during training
- [Criteria](../core/criteria.html) - Importance measures for selecting filters to prune
- [SparsifyCallback](../sparse/sparsify_callback.html) - Unstructured pruning alternative