In [1]:
from fasterai.core.criteria import *
from fasterai.core.schedule import *
from fasterai.regularize.all import *
from fastai.vision.all import *

## Overview

**Group Regularization** is a technique that encourages structured sparsity in neural networks during training. Unlike standard L2 regularization (weight decay), which penalizes individual weights, group regularization penalizes *groups* of weights together, such as entire filters, kernels, or channels.

### Why Use Group Regularization?

When preparing a model for **structured pruning**, you want entire structures (filters, channels) to become unimportant, not just individual weights. Group regularization pushes these structures toward zero *during training*, making subsequent pruning:

1. **More effective** - Pruned structures are already near-zero, minimizing accuracy loss
2. **Cleaner** - Clear separation between important and unimportant structures
3. **Hardware-friendly** - Structured sparsity maps well to GPU/CPU acceleration

### The RegularizeCallback

The `RegularizeCallback` adds a regularization term to the loss function:

$$\mathcal{L}_{total} = \mathcal{L}_{task} + \lambda \sum_{g \in \text{groups}} \|W_g\|_p$$

Where $\lambda$ is the regularization weight, $W_g$ are the weight groups at your chosen granularity, and $\|\cdot\|_p$ is the norm applied to each group.
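
To make the penalty term concrete, here is a minimal NumPy sketch that evaluates $\lambda \sum_g \|W_g\|_p$ for a single convolution weight, treating each output filter as one group and taking $p = 2$. The function name `group_penalty`, the choice of $p$, and the `(out_channels, in_channels, kh, kw)` layout are assumptions made for illustration; this is not how `RegularizeCallback` is implemented internally.

```python
import numpy as np

def group_penalty(weight, lam, p=2):
    """Return lam * (sum over filters of the p-norm of each filter).

    `weight` is assumed to have shape (out_channels, in_channels, kh, kw),
    so each slice weight[i] plays the role of one group W_g.
    """
    # Flatten every filter into a row: one row per group
    groups = weight.reshape(weight.shape[0], -1)
    # Vector p-norm of each row, then sum over all groups
    norms = np.linalg.norm(groups, ord=p, axis=1)
    return lam * norms.sum()

# Two filters of shape (1, 2, 2): one all zeros, one all ones
w = np.stack([np.zeros((1, 2, 2)), np.ones((1, 2, 2))])
penalty = group_penalty(w, lam=0.5)
print(penalty)  # 0.5 * (0 + sqrt(4)) = 1.0
```

A filter that is entirely zero contributes nothing to the penalty, which is why minimizing this term tends to switch off whole groups rather than scattered individual weights.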

## 1. Setup and Data

Let's start by loading a dataset and establishing a baseline model without regularization.

In [8]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

## 2. Baseline Training (No Regularization)

First, we train a model without any regularization to establish a baseline accuracy.

In [9]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

learn.fit_one_cycle(5)

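| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.682476 | 0.530806 | 0.847091 | 00:02 |
| 1 | 0.403894 | 0.268916 | 0.905277 | 00:02 |
| 2 | 0.235484 | 0.212882 | 0.918133 | 00:03 |
| 3 | 0.119361 | 0.198808 | 0.921516 | 00:03 |
| 4 | 0.067766 | 0.18581 | 0.928958 | 00:03 |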


## 3. Training with Group Regularization

Now let's train with `RegularizeCallback`. We'll configure it with:

- **`criteria=squared_final`**: Uses squared weight magnitudes for regularization
- **`granularity='weight'`**: Regularizes at individual weight level (try `'filter'` for structured pruning prep)
- **`weight=1e-3`**: Regularization strength (higher = more aggressive)
- **`schedule=one_cycle`**: Varies regularization strength during training
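
One way to picture a schedule is as a multiplier on the regularization weight that changes with training progress. The sketch below uses a hypothetical logistic ramp, `ramp_schedule`, that moves from 0 to the full weight as progress goes from 0 to 1. The function, its `steepness` parameter, and its shape are assumptions for illustration only and are not fasterai's definition of `one_cycle`.

```python
import math

def ramp_schedule(progress, steepness=10.0):
    """Hypothetical smooth ramp: 0 at progress=0, 1 at progress=1."""
    def raw(x):
        return 1.0 / (1.0 + math.exp(-steepness * (x - 0.5)))
    lo, hi = raw(0.0), raw(1.0)
    # Rescale so the ramp is exactly 0 at the start and 1 at the end
    return (raw(progress) - lo) / (hi - lo)

def scheduled_weight(base_weight, progress):
    """Regularization weight in effect at a given fraction of training."""
    return base_weight * ramp_schedule(progress)

# Print the effective weight at a few points in training
for pct in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"{pct:.2f} -> {scheduled_weight(1e-3, pct):.2e}")
```

With a ramp of this kind the penalty is small early in training, when the network is still learning the task, and reaches its full strength only near the end.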

In [11]:
reg_cb = RegularizeCallback(squared_final, 'weight', 1e-3, schedule=one_cycle)

In [12]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [13]:
learn.fit_one_cycle(5, cbs=reg_cb)

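| epoch | train_loss | valid_loss | accuracy | time |
|-------|------------|------------|----------|------|
| 0 | 0.785396 | 1.031241 | 0.818674 | 00:02 |
| 1 | 1.362821 | 2.267276 | 0.893099 | 00:02 |
| 2 | 3.505133 | 4.478498 | 0.891069 | 00:02 |
| 3 | 4.259987 | 4.4061 | 0.927605 | 00:02 |
| 4 | 4.243036 | 4.337528 | 0.94046 | 00:02 |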


## 4. Comparing Results

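After training, you should observe:
- Similar accuracy. Regularization adds a constraint, so a slight drop is possible, although in the runs above the regularized model finished at 0.940 against 0.929 for the baseline
- Much larger reported loss values than the baseline, which is consistent with the regularization term being added to the loss
- Weights that are more concentrated around zero
- A cleaner weight distribution for subsequent pruning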

To visualize the effect, you can plot weight histograms:

```python
import matplotlib.pyplot as plt

# Get all conv weights
weights = torch.cat([m.weight.data.flatten() for m in learn.model.modules() 
                     if isinstance(m, nn.Conv2d)])

plt.hist(weights.cpu().numpy(), bins=100, alpha=0.7)
plt.xlabel('Weight Value')
plt.ylabel('Count')
plt.title('Weight Distribution After Group Regularization')
plt.show()
```
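
A histogram of individual weights does not show whether *whole filters* have collapsed. As a complementary check, the sketch below computes the L2 norm of every filter and reports the fraction that fall below a threshold. It operates on NumPy arrays; the helper name `near_zero_filter_fraction` and the default threshold are hypothetical and should be adapted to your model.

```python
import numpy as np

def near_zero_filter_fraction(conv_weights, threshold=1e-2):
    """Fraction of filters whose L2 norm is below `threshold`.

    `conv_weights` is an iterable of arrays shaped
    (out_channels, in_channels, kh, kw), one per conv layer.
    """
    # One norm per filter, concatenated across all layers
    norms = np.concatenate([
        np.linalg.norm(w.reshape(w.shape[0], -1), axis=1)
        for w in conv_weights
    ])
    return float((norms < threshold).mean())

# Toy layer with 4 filters, two of which are exactly zero
layer = np.ones((4, 3, 3, 3))
layer[[0, 2]] = 0.0
print(near_zero_filter_fraction([layer]))  # 0.5
```

For a trained model, the input arrays can be collected with `m.weight.detach().cpu().numpy()` for each `nn.Conv2d` module in `learn.model.modules()`.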

## 5. Parameter Guide

### Choosing Granularity

| Granularity | Effect | Best For |
|-------------|--------|----------|
| `'weight'` | Regularizes individual weights | Unstructured pruning, general sparsity |
| `'filter'` | Regularizes entire Conv2d filters | **Structured pruning** (recommended) |
| `'kernel'` | Regularizes 2D kernels within filters | Moderate structure |
| `'channel'` | Regularizes input channels | Channel pruning |
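
The difference between granularities comes down to which axes of the weight tensor are collapsed into a single group. The following sketch illustrates this for a 4D convolution weight, assuming the `(out_channels, in_channels, kernel_h, kernel_w)` layout. The axis mapping is inferred from the descriptions in the table above and is an assumption, not fasterai's internal definition; `group_norms` is a hypothetical helper.

```python
import numpy as np

# Assumed layout: (out_channels, in_channels, kernel_h, kernel_w)
# Axes summed over when computing one norm per group (illustrative mapping)
REDUCE_AXES = {
    "filter": (1, 2, 3),   # one group per output filter
    "kernel": (2, 3),      # one group per (filter, input channel) pair
    "channel": (0, 2, 3),  # one group per input channel
}

def group_norms(weight, granularity):
    """L2 norm of each group at the given granularity."""
    if granularity == "weight":
        return np.abs(weight)  # every weight is its own group
    return np.sqrt((weight ** 2).sum(axis=REDUCE_AXES[granularity]))

w = np.random.default_rng(0).normal(size=(8, 4, 3, 3))
for g in ("weight", "filter", "kernel", "channel"):
    print(g, group_norms(w, g).shape)
# weight  -> (8, 4, 3, 3)
# filter  -> (8,)
# kernel  -> (8, 4)
# channel -> (4,)
```

Coarser granularities produce fewer, larger groups, so each group that reaches zero corresponds to a bigger structure that can be removed.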

### Choosing Regularization Weight

| Weight Range | Effect |
|--------------|--------|
| `1e-6 - 1e-5` | Very light regularization, minimal accuracy impact |
| `1e-5 - 1e-4` | Moderate regularization, good balance |
| `1e-4 - 1e-3` | Strong regularization, may reduce accuracy |
| `> 1e-3` | Very aggressive, use with caution |

**Tip:** Start with `1e-5` and increase if weights don't concentrate toward zero.

### Recommended Workflow

```python
# Illustrative epoch counts; tune these for your task
epochs, fine_tune_epochs = 5, 3

# 1. Train with filter-level regularization
reg_cb = RegularizeCallback(
    criteria=large_final,      # or squared_final
    granularity='filter',      # for structured pruning
    weight=1e-4,
    schedule=one_cycle,
    verbose=True
)
learn.fit(epochs, cbs=[reg_cb])

# 2. Prune the regularized model
from fasterai.prune.all import *
pruner = Pruner(learn.model, sparsity=0.3, context='local', criteria=large_final)
pruner.prune_model()

# 3. Fine-tune
learn.fit(fine_tune_epochs)
```
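
To see why step 1 helps step 2, it can be useful to inspect which filters a norm-based criterion would select. The sketch below ranks the filters of each layer by L2 norm and picks the lowest fraction per layer. It is a simplified stand-in written in NumPy, not fasterai's `Pruner`; the helper name `filters_to_prune_local` and the reading of "local" as "applied to each layer separately" are assumptions for this illustration.

```python
import numpy as np

def filters_to_prune_local(conv_weights, sparsity):
    """For each layer, return the indices of its lowest-norm filters.

    The sparsity target is applied to every layer separately
    (the assumed meaning of 'local' in this sketch).
    """
    selected = []
    for w in conv_weights:
        norms = np.linalg.norm(w.reshape(w.shape[0], -1), axis=1)
        n_prune = int(sparsity * len(norms))
        # argsort is ascending, so the first n_prune entries are the smallest
        selected.append(np.sort(np.argsort(norms)[:n_prune]))
    return selected

# Toy layer: 10 filters whose norms grow with the filter index
layer = np.stack([np.full((2, 3, 3), i + 1.0) for i in range(10)])
to_prune = filters_to_prune_local([layer], sparsity=0.3)
print(to_prune)  # selects filters 0, 1 and 2
```

When regularization has done its job, the norms of the selected filters are already close to zero, so removing them changes the network's output very little.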

## Summary

| Concept | Description |
|---------|-------------|
| **Group Regularization** | Penalizes groups of weights to encourage structured sparsity |
| **RegularizeCallback** | fastai callback that adds a regularization term to the loss |
| **Granularity** | Level at which to group weights (`'weight'`, `'filter'`, `'kernel'`, `'channel'`) |
| **Schedule** | Varies regularization strength during training |
| **Typical Use** | Pre-pruning preparation to make structured pruning more effective |

---

## See Also

- [Criteria](../../core/criteria.html) - Importance measures used for regularization
- [Schedules](../../core/schedules.html) - Control regularization strength over training
- [Pruner](../../prune/pruner.html) - Apply structured pruning after regularization
- [Sparsifier](../../sparse/sparsifier.html) - Apply unstructured sparsification