In [None]:
#| default_exp regularize.regularize_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import annotations
import warnings

from fastai.callback.all import *
from fastcore.basics import store_attr, listify
from fasterai.core.criteria import *
from fasterai.core.granularity import *
from fasterai.core.schedule import *

import torch
import torch.nn as nn
from typing import Type

## Overview

The `RegularizeCallback` applies structured regularization during training to encourage weight sparsity at various granularities. This is useful as a pre-pruning step: by regularizing groups of weights toward zero during training, subsequent pruning can remove more parameters with less accuracy loss.

**Key Features:**
- Supports multiple granularity levels (`'weight'`, `'vector'`, `'kernel'`, `'filter'`)
- Compatible with any criteria from `fasterai.core.criteria`
- Optional scheduling to vary regularization strength over training

In [None]:
#| export
class RegularizeCallback(Callback):
    "Add a criteria-based penalty on the weights of selected layers to the training loss"
    def __init__(self, 
                 criteria: Criteria | list[Criteria],            # Criteria(s) to use for regularization
                 granularity: str | list[str],                   # Granularity level(s) for grouping
                 weight: float = 0.01,                           # Regularization weight
                 layer_types: Type | list[Type] = nn.Conv2d,     # Layer types to apply regularization to
                 schedule: Schedule | None = None,               # Optional schedule for regularization weight
                 verbose: bool = False                           # Whether to report regularization weight
    ):
        "Callback to apply regularization using criteria during training"
        store_attr()
        # Accept either a single value or a list for these three arguments
        self.criteria = listify(criteria)
        self.granularity = listify(granularity)
        self.layer_types = listify(layer_types)
        # Effective weight used by `get_norm`; rescaled by `before_batch` when a schedule is set
        self.current_weight = weight
        
    def before_batch(self) -> None:
        "Update regularization weight if scheduled"
        if self.schedule is not None:
            # Scale the base weight by the schedule's progress at the current point of training
            progress = self.schedule.progress(self.pct_train)
            self.current_weight = self.weight * progress
        
    def after_loss(self) -> None:
        "Add the regularization penalty to the training loss (validation batches are left untouched)"
        # The penalty is a training objective only: adding it while validating would inflate
        # `valid_loss` (and anything monitoring it) and cost a pass over every layer per batch.
        if not self.training: return
        reg = self.get_norm()
        # Out-of-place add: rebind rather than mutate the tensor produced by the loss function
        self.learn.loss_grad = self.learn.loss_grad + reg
        # Keep the logged loss equal to the one being back-propagated
        self.learn.loss = self.learn.loss_grad.clone()
    
    def _iter_layers(self):
        "Yield every module of the model that is an instance of one of `layer_types` and has a `weight`"
        for m in self.learn.model.modules():
            if any(isinstance(m, lt) for lt in self.layer_types) and hasattr(m, 'weight'):
                yield m
            
    def get_norm(self) -> torch.Tensor:
        """Compute the weighted penalty summed over every (criteria, granularity, layer) combination.

        Each term is `current_weight * sum(|crit.f(weight)|)`, taken after an intermediate
        reduction along `Granularities.get_dim(m, g)`. Combinations raising `KeyError`,
        `ValueError` or `RuntimeError` are skipped with a warning.

        Returns a 0-dim tensor (`tensor(0.0)` when no term could be computed).
        """
        # NOTE(review): the per-group sums are summed again, so each term equals the plain L1
        # norm of `crit.f(weight)` regardless of `g`; the granularity currently only decides
        # whether a layer is skipped. A group-lasso penalty would need a per-group norm here
        # (e.g. an L2 norm over `get_dim` dims) -- confirm intended semantics.
        # Pre-filter modules once
        layers = list(self._iter_layers())
        
        layer_regs = []
        for crit in self.criteria:
            for g in self.granularity:
                for m in layers:
                    try:
                        scores = crit.f(m.weight)[None].abs().sum(Granularities.get_dim(m, g))
                        layer_regs.append(self.current_weight * scores.sum())
                    except (KeyError, ValueError) as e:
                        # Typically an unsupported granularity for this layer type
                        warnings.warn(f"Skipping regularization for {type(m).__name__}: {e}")
                    except RuntimeError as e:
                        warnings.warn(f"Runtime error in regularization for {type(m).__name__}: {e}")
        
        return torch.stack(layer_regs).sum() if layer_regs else torch.tensor(0.0)
    
    def after_epoch(self) -> None:
        "Report current regularization weight if verbose"
        if self.verbose:
            print(f"Current regularization weight: {self.current_weight:.6f}")

In [None]:
# Render the API reference (signature and per-parameter comments) for the callback
show_doc(RegularizeCallback)

Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


---

[source](https://github.com/FasterAI-Labs/fasterai/blob/master/fasterai/regularize/regularize_callback.py#L21){target="_blank" style="float:right; font-size:smaller"}

### RegularizeCallback

```python

def RegularizeCallback(
    criteria:Criteria | list[Criteria], # Criteria(s) to use for regularization
    granularity:str | list[str], # Granularity level(s) for grouping
    weight:float=0.01, # Regularization weight
    layer_types:Type | list[Type]=<class 'torch.nn.modules.conv.Conv2d'>, # Layer types to apply regularization to
    schedule:Schedule | None=None, # Optional schedule for regularization weight
    verbose:bool=False, # Whether to report regularization weight
):


```

*Callback to apply regularization using criteria during training*

**Parameters:**
- `criteria`: Importance criteria to use for computing regularization (e.g., `large_final`)
- `granularity`: Level at which to group weights (`'weight'`, `'vector'`, `'kernel'`, `'filter'`)
- `weight`: Regularization coefficient (higher = stronger regularization)
- `layer_types`: Module types to regularize (default: `nn.Conv2d`)
- `schedule`: Optional schedule to vary regularization strength over training
- `verbose`: Print regularization weight after each epoch

---

## Usage Example

Apply filter-level L1 regularization to encourage entire filters to become unimportant (making them easier to prune later):

```python
from fasterai.regularize.regularize_callback import RegularizeCallback
from fasterai.core.criteria import large_final

# Apply L1 regularization at filter granularity
cb = RegularizeCallback(
    criteria=large_final,
    granularity='filter',
    weight=0.01,
    verbose=True
)

learn.fit(10, cbs=[cb])
```

**Typical Workflow:**
1. Train with `RegularizeCallback` to push unimportant filter groups toward zero
2. After training, use `PruneCallback` or `Pruner` to remove the zeroed-out structures
3. Fine-tune the pruned model to recover any lost accuracy

In [None]:
#| hide
from fastcore.test import *

# Scalar arguments: weight is stored verbatim, criteria/granularity become one-element lists
cb = RegularizeCallback(criteria=large_final, granularity='filter', weight=1e-4)
for _attr in ('weight', 'current_weight'):
    test_eq(getattr(cb, _attr), 1e-4)
for _attr in ('criteria', 'granularity'):
    test_eq(len(getattr(cb, _attr)), 1)

# List arguments are kept element-for-element
cb_m = RegularizeCallback(criteria=[large_final, large_final], granularity=['filter', 'weight'])
test_eq(len(cb_m.criteria), 2)
test_eq(len(cb_m.granularity), 2)

# Defaults: layer_types is a one-element list holding nn.Conv2d, and no schedule is attached
test_eq(list(cb.layer_types), [nn.Conv2d])
assert cb.schedule is None

In [None]:
#| hide
#| slow
# Training with RegularizeCallback — verify it runs without error
from torch.utils.data import TensorDataset
from fastai.data.core import DataLoaders
from fastai.learner import Learner

# Seed so the random data and the model initialisation are reproducible across runs
torch.manual_seed(42)

_model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10)
)

# 64 random 3x8x8 inputs over 10 classes: first 48 for training, last 16 for validation
_X = torch.randn(64, 3, 8, 8)
_y = torch.randint(0, 10, (64,))
_dls = DataLoaders.from_dsets(
    TensorDataset(_X[:48], _y[:48]),
    TensorDataset(_X[48:], _y[48:]),
    bs=16, device='cpu'
)

_cb = RegularizeCallback(criteria=large_final, granularity='filter', weight=1e-4)
_learn = Learner(_dls, _model, loss_func=nn.CrossEntropyLoss(), cbs=[_cb])
_learn.fit(2)  # verify it runs end-to-end without error

# Beyond "no exception": the last epoch's recorded train/valid losses must be finite
assert torch.isfinite(torch.tensor([float(v) for v in _learn.recorder.values[-1]])).all()

epoch,train_loss,valid_loss,time
0,2.301912,2.272374,00:00
1,2.29903,2.271772,00:00


---

## See Also

- [Sparsifier](../sparse/sparsifier.html) - Apply sparsification after regularization pushes weights to zero
- [Criteria](../core/criteria.html) - Importance measures that can leverage regularized weights
- [SparsifyCallback](../sparse/sparsify_callback.html) - Combine with sparsification for gradual pruning