In [None]:
#| include: false
from nbdev.showdoc import *
import warnings
warnings.filterwarnings('ignore')

In [None]:
#| include: false
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.prune.all import *
from fasterai.core.criteria import *
import torch_pruning as tp
from torch_pruning.pruner import function

import torch
import torch.nn as nn
import torch.nn.functional as F

## Overview

**Structured Pruning** removes entire filters, channels, or layers from neural networks, resulting in genuinely smaller and faster models. Unlike sparsification (which zeros individual weights), pruning physically removes parameters.
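To make the distinction concrete, here is a minimal sketch in plain PyTorch, independent of fasterai. The two-layer toy network, the L1-norm importance score, and the choice to keep 5 of 8 filters are assumptions made for illustration only. Zeroing filters leaves the tensor shapes untouched, while removing them yields smaller layers that compute the same output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
x = torch.randn(1, 3, 32, 32)

# Illustrative importance score: L1 norm of each conv1 filter.
importance = conv1.weight.detach().abs().sum(dim=(1, 2, 3))
keep = importance.argsort(descending=True)[:5].sort().values  # keep 5 of 8 filters

# Sparsification: zero the dropped filters; the weight tensor keeps its shape.
mask = torch.zeros(8, dtype=torch.bool)
mask[keep] = True
sparse_weight = conv1.weight.detach() * mask.view(-1, 1, 1, 1)
out_sparse = conv2(F.conv2d(x, sparse_weight, padding=1))

# Structured pruning: build physically smaller layers from the kept filters.
small1 = nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=False)
small2 = nn.Conv2d(5, 16, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    small1.weight.copy_(conv1.weight[keep])     # drop output filters of conv1
    small2.weight.copy_(conv2.weight[:, keep])  # drop matching input channels of conv2
out_small = small2(small1(x))

def n_params(*modules):
    """Total number of parameters across the given modules."""
    return sum(p.numel() for m in modules for p in m.parameters())

dense_params = n_params(conv1, conv2)
small_params = n_params(small1, small2)
print(f"dense: {dense_params} params, pruned: {small_params} params")
print("same output:", torch.allclose(out_sparse, out_small, atol=1e-5))
```

Note that removing filters from one layer also shrinks the input channels of the next layer, which is why the parameter count can fall faster than the fraction of filters removed.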

### Why Use Structured Pruning?

| Approach | What's Removed | Model Size | Speed Benefit | Hardware |
|----------|----------------|------------|---------------|----------|
| Sparsification | Individual weights | Same | Requires sparse support | Specialized |
| **Structured Pruning** | Entire filters | **Smaller** | **Immediate** | Standard |

### Key Benefits

- **Real speedup** - Filters are physically removed, so the smaller dense layers run faster on standard hardware without sparse kernels
- **Smaller models** - Reduced memory footprint for deployment
- **Gradual pruning** - Remove filters progressively during training
- **Flexible targeting** - Global or local pruning strategies

## 1. Setup and Baseline

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

First, train a baseline ResNet-18 to establish expected performance:

In [None]:
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
learn.fit_one_cycle(1)

epoch,train_loss,valid_loss,accuracy,time
0,0.57391,0.346901,0.848444,00:02


In [None]:
base_macs, base_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))

## 2. Training with PruneCallback

Now let's train with gradual filter pruning. We'll remove 40% of filters using a one-cycle schedule:

Configuration:
- **`pruning_ratio=0.4`** - Remove 40% of filters
- **`context='global'`** - Remove least important filters from anywhere in the network
- **`criteria=large_final`** - Keep filters with largest final weights
- **`schedule=one_cycle`** - Gradually increase pruning following one-cycle pattern
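The difference between a global and a local context can be sketched without any pruning library. The example below is a simplified illustration, not fasterai's implementation: the layer names, the random importance scores, and the helper functions `local_prune` and `global_prune` are all hypothetical.

```python
import torch

torch.manual_seed(0)
# Hypothetical per-filter importance scores for two layers (e.g. weight norms).
scores = {
    "layer_a": torch.rand(10),        # scores in [0, 1)
    "layer_b": torch.rand(20) + 0.5,  # systematically larger scores
}
ratio = 0.4

def local_prune(scores, ratio):
    """Remove the lowest-scoring `ratio` of filters inside each layer."""
    return {
        name: s.argsort()[: int(round(len(s) * ratio))].tolist()
        for name, s in scores.items()
    }

def global_prune(scores, ratio):
    """Rank all filters together and remove the lowest-scoring `ratio` overall."""
    all_scores = torch.cat(list(scores.values()))
    n_remove = int(round(len(all_scores) * ratio))
    threshold = all_scores.sort().values[n_remove - 1]
    return {
        name: (s <= threshold).nonzero().flatten().tolist()
        for name, s in scores.items()
    }

local_removed = local_prune(scores, ratio)
global_removed = global_prune(scores, ratio)
for name in scores:
    print(name, "local:", len(local_removed[name]), "global:", len(global_removed[name]))
```

Both strategies remove the same total number of filters. Local pruning spreads the removals evenly across layers, whereas global pruning concentrates them wherever the scores are lowest, so some layers may shrink far more than others.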

In [None]:
pr_cb = PruneCallback(pruning_ratio=0.4, context='global', criteria=large_final, schedule=one_cycle)
learn.fit_one_cycle(10, cbs=pr_cb)

Ignoring output layer: 1.8
Total ignored layers: 1


epoch,train_loss,valid_loss,accuracy,time
0,0.350107,0.277946,0.875507,00:02
1,0.270526,0.309414,0.881597,00:03
2,0.247778,0.240875,0.903924,00:03
3,0.224332,0.608088,0.70839,00:03
4,0.193209,0.22106,0.897835,00:03
5,0.249345,0.259771,0.895805,00:04
6,0.266264,0.265805,0.890392,00:04
7,0.234256,0.263015,0.888363,00:02
8,0.224429,0.255041,0.890392,00:02
9,0.196133,0.255395,0.892422,00:03


Sparsity at the end of epoch 0: 0.39%
Sparsity at the end of epoch 1: 1.54%
Sparsity at the end of epoch 2: 5.60%
Sparsity at the end of epoch 3: 15.91%
Sparsity at the end of epoch 4: 29.13%
Sparsity at the end of epoch 5: 36.64%
Sparsity at the end of epoch 6: 39.12%
Sparsity at the end of epoch 7: 39.79%
Sparsity at the end of epoch 8: 39.96%
Sparsity at the end of epoch 9: 40.00%
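The logged sparsity rises slowly, accelerates mid-training, and flattens near the target. A logistic curve has that shape. The sketch below is an illustration under stated assumptions: the function name `logistic_schedule` and the constants `alpha=14` and `beta=6` are chosen here to resemble the log and are not taken from the fasterai source.

```python
import math

def logistic_schedule(pos, target=40.0, alpha=14.0, beta=6.0):
    """Sparsity (in %) at training progress `pos` in [0, 1].

    alpha and beta are assumed shape constants; the curve is scaled so that
    it reaches exactly `target` at pos = 1.
    """
    scale = 1 + math.exp(-alpha + beta)
    return target * scale / (1 + math.exp(-alpha * pos + beta))

# Evaluate at the end of each of 10 epochs.
values = [logistic_schedule((epoch + 1) / 10) for epoch in range(10)]
for epoch, value in enumerate(values):
    print(f"after epoch {epoch}: {value:.2f}%")
```

Most of the pruning happens in the middle epochs, which matches the temporary accuracy dip at epoch 3 in the training log and leaves the final epochs for the network to recover at a nearly fixed size.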


In [None]:
pruned_macs, pruned_params = tp.utils.count_ops_and_params(learn.model, torch.randn(1,3,224,224).to(default_device()))

## 3. Measuring Compression

The pruned model has fewer parameters and requires less compute:

In [None]:
print(f'The pruned model has {pruned_macs/base_macs:.2f} the compute of original model')

The pruned model has 0.63 the compute of original model


In [None]:
print(f'The pruned model has {pruned_params/base_params:.2f} the parameters of original model')

The pruned model has 0.18 the parameters of original model
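The two ratios printed above can be turned into compression factors and percentages removed. This is plain arithmetic on the rounded values from the outputs; the helper names are hypothetical.

```python
macs_ratio = 0.63    # pruned_macs / base_macs, from the output above
params_ratio = 0.18  # pruned_params / base_params, from the output above

def compression_factor(ratio):
    """How many times smaller the pruned quantity is than the original."""
    return 1.0 / ratio

def percent_removed(ratio):
    """Share of the original quantity that was removed, in percent."""
    return 100.0 * (1.0 - ratio)

for name, ratio in [("MACs", macs_ratio), ("Parameters", params_ratio)]:
    print(name, compression_factor(ratio), percent_removed(ratio))
```

Because the printed ratios are rounded to two decimals, the resulting factors are approximate: a parameter ratio between 0.175 and 0.185 corresponds to roughly 5.4x to 5.7x fewer parameters.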


## Summary

| Metric | Original | Pruned (40%) | Improvement |
|--------|----------|--------------|-------------|
| Parameters | 100% | ~18% | **5.5x smaller** |
| Compute (MACs) | 100% | ~63% | **1.6x fewer ops** |
| Accuracy | 84.8% (after 1 epoch) | 89.2% (after 10 pruning epochs) | ~1 point below the 90.4% peak reached during pruning |

### Parameter Reference

| Parameter | Description | Example |
|-----------|-------------|---------|
| `pruning_ratio` | Fraction of filters to remove | `0.4` |
| `context` | Pruning scope | `'global'` (whole model) or `'local'` (per-layer) |
| `criteria` | Importance measure | `large_final`, `magnitude`, `taylor` |
| `schedule` | How pruning increases over training | `one_cycle`, `cos`, `linear` |

---

## See Also

- [Pruner](../../prune/pruner.html) - Lower-level API for one-shot pruning
- [Sparsifier](../../sparse/sparsifier.html) - For unstructured sparsification
- [Schedules](../../core/schedules.html) - Available pruning schedules
- [Criteria](../../core/criteria.html) - Filter importance measures
- [YOLO Pruning Tutorial](YOLOV8.html) - Pruning detection models