## Overview

The `Pruner` class performs **structured pruning**: it physically removes entire filters (output channels) from your network, along with the matching input channels of the layers that consume them. Sparsification zeros weights but keeps the architecture. Structured pruning instead produces a genuinely smaller model that runs faster on standard hardware.

### Sparsifier vs Pruner

| Aspect | Sparsifier | Pruner |
|--------|------------|--------|
| **What it removes** | Individual weights → zeros | Entire filters → gone |
| **Architecture** | Unchanged (same shapes) | Smaller (fewer channels) |
| **Speedup** | Needs sparse hardware | Immediate on any hardware |
| **Use case** | Research, sparse accelerators | Production deployment |
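To make the table concrete, here is a plain-Python sketch on a toy weight matrix with four filters of three weights each, scored by L1 norm. It only illustrates the idea and does not use fasterai. Lists stand in for tensors.

```python
# Toy "conv layer": 4 filters (rows), each with 3 weights
weights = [
    [0.9, -0.8, 0.7],
    [0.01, 0.02, -0.01],
    [0.5, 0.4, -0.6],
    [-0.03, 0.0, 0.02],
]

def l1(filter_weights):
    """L1 norm of one filter, used here as its importance score."""
    return sum(abs(w) for w in filter_weights)

# Indices of the 2 filters with the smallest L1 norm
weakest = sorted(range(len(weights)), key=lambda i: l1(weights[i]))[:2]

# Sparsifier-style: zero the weak filters, so the shape stays 4 x 3
sparsified = [[0.0] * len(f) if i in weakest else list(f) for i, f in enumerate(weights)]

# Pruner-style: drop the weak filters, so the shape becomes 2 x 3
pruned = [list(f) for i, f in enumerate(weights) if i not in weakest]

print(len(sparsified), len(sparsified[0]))
print(len(pruned), len(pruned[0]))
```

Both variants hold the same non-zero values. Only the pruned version is actually smaller, which is why it needs no sparse-aware hardware to run faster.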

**When to use Pruner:**
- You need a smaller model file
- You want faster inference without special hardware
- You're deploying to edge devices or mobile

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from fasterai.prune.pruner import Pruner
from fasterai.core.criteria import large_final, random

## 1. Basic Pruning

Let's start with a ResNet18 and prune 30% of its filters:

In [None]:
model = resnet18(weights=None)

print('Before pruning:')
print(f'  conv1: {model.conv1}')
print(f'  layer1[0].conv1: {model.layer1[0].conv1}')
print(f'  Parameters: {sum(p.numel() for p in model.parameters()):,}')

Before pruning:
  conv1: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  layer1[0].conv1: Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  Parameters: 11,689,512


In [None]:
pruner = Pruner(
    model, 
    pruning_ratio=30,      # Remove 30% of filters
    context='local',       # Prune each layer independently
    criteria=large_final   # Keep filters with largest weights
)
pruner.prune_model()

print('\nAfter pruning:')
print(f'  conv1: {model.conv1}')
print(f'  layer1[0].conv1: {model.layer1[0].conv1}')
params_after = sum(p.numel() for p in model.parameters())
print(f'  Parameters: {params_after:,}')
print(f'  Reduction: {100*(1 - params_after/11689512):.1f}%')

Ignoring output layer: fc
Total ignored layers: 1

After pruning:
  conv1: Conv2d(3, 44, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  layer1[0].conv1: Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  Parameters: 5,820,556
  Reduction: 50.2%


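Notice that the channel counts changed: `Conv2d(3, 64, ...)` became `Conv2d(3, 44, ...)`, so the model is genuinely smaller. Removing 30% of the filters cut about 50% of the parameters because a convolution's weight count is `in_channels × out_channels × k × k`. When both channel counts shrink to 70%, the weights shrink to roughly 0.7 × 0.7 ≈ 49% of their original size.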
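You can check the weight arithmetic for `layer1[0].conv1` directly in plain Python. The helper `kept` is hypothetical. It assumes the kept count is rounded down, which matches the 64 → 44 change printed above.

```python
def conv_params(in_ch, out_ch, k):
    """Weight count of a bias-free k x k Conv2d."""
    return in_ch * out_ch * k * k

def kept(channels, ratio_pct):
    # Assumption: channels kept = floor(channels * (1 - ratio)); 64 * 0.7 = 44.8 -> 44
    return int(channels * (1 - ratio_pct / 100))

c = kept(64, 30)
before = conv_params(64, 64, 3)   # layer1[0].conv1 before pruning
after = conv_params(c, c, 3)      # both input and output channels shrank
reduction = 1 - after / before
print(f"{c} channels kept, {before:,} -> {after:,} weights ({reduction:.1%} fewer)")
```

For this single layer the reduction is slightly above 50%. Layers with one fixed side shrink only linearly. Examples are the RGB input of `conv1` and the 1000 outputs of `fc`. This is one reason the whole-model figure lands a little lower.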

**Key point:** The Pruner automatically handles layer dependencies. When you remove output channels from one layer, it removes the corresponding input channels from the next layer.
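Here is a plain-Python sketch of that bookkeeping for two chained layers, with weight matrices indexed `[out_channel][in_channel]`. `prune_pair` is a hypothetical helper, not part of fasterai. A full implementation must also handle residual connections and batch-norm layers, which this sketch ignores.

```python
# Layer A maps 2 inputs -> 3 outputs; layer B maps A's 3 outputs -> 2 outputs
layer_a = [[1, 2], [3, 4], [5, 6]]
layer_b = [[10, 11, 12], [13, 14, 15]]

def prune_pair(a, b, drop):
    """Remove output channels `drop` from `a` and the matching input channels from `b`."""
    new_a = [row for i, row in enumerate(a) if i not in drop]
    new_b = [[w for j, w in enumerate(row) if j not in drop] for row in b]
    return new_a, new_b

a2, b2 = prune_pair(layer_a, layer_b, drop={1})
# Shapes stay consistent: A's output count equals B's input count
print(len(a2), len(b2[0]))
```

If only `layer_a` were pruned, `layer_b` would still expect three inputs and the forward pass would fail. Keeping the two sides in sync is exactly the dependency handling described above.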

## 2. Local vs Global Pruning

The `context` parameter controls how filters are selected for pruning:

| Context | Behavior | Best for |
|---------|----------|----------|
| `'local'` | Each layer loses the same % of filters | Uniform compression |
| `'global'` | Importance is compared across all layers | Usually better accuracy at the same overall ratio |

In [None]:
# Local: each layer independently loses 50% of its filters
model_local = resnet18(weights=None)
pruner = Pruner(model_local, 50, 'local', large_final)
pruner.prune_model()

print('\nLocal pruning (each layer loses 50%):')
print(f'  layer1[0].conv1: Conv2d({model_local.layer1[0].conv1.in_channels}, {model_local.layer1[0].conv1.out_channels}, ...)')
print(f'  layer2[0].conv1: Conv2d({model_local.layer2[0].conv1.in_channels}, {model_local.layer2[0].conv1.out_channels}, ...)')
print(f'  layer3[0].conv1: Conv2d({model_local.layer3[0].conv1.in_channels}, {model_local.layer3[0].conv1.out_channels}, ...)')
print(f'  layer4[0].conv1: Conv2d({model_local.layer4[0].conv1.in_channels}, {model_local.layer4[0].conv1.out_channels}, ...)')

Ignoring output layer: fc
Total ignored layers: 1

Local pruning (each layer loses 50%):
  layer1[0].conv1: Conv2d(32, 32, ...)
  layer2[0].conv1: Conv2d(32, 64, ...)
  layer3[0].conv1: Conv2d(64, 128, ...)
  layer4[0].conv1: Conv2d(128, 256, ...)


In [None]:
# Global: remove the least important filters across the entire network
model_global = resnet18(weights=None)
pruner = Pruner(model_global, 50, 'global', large_final)
pruner.prune_model()

print('\nGlobal pruning (importance compared across layers):')
print(f'  layer1[0].conv1: Conv2d({model_global.layer1[0].conv1.in_channels}, {model_global.layer1[0].conv1.out_channels}, ...)')
print(f'  layer2[0].conv1: Conv2d({model_global.layer2[0].conv1.in_channels}, {model_global.layer2[0].conv1.out_channels}, ...)')
print(f'  layer3[0].conv1: Conv2d({model_global.layer3[0].conv1.in_channels}, {model_global.layer3[0].conv1.out_channels}, ...)')
print(f'  layer4[0].conv1: Conv2d({model_global.layer4[0].conv1.in_channels}, {model_global.layer4[0].conv1.out_channels}, ...)')

Ignoring output layer: fc
Total ignored layers: 1

Global pruning (importance compared across layers):
  layer1[0].conv1: Conv2d(64, 64, ...)
  layer2[0].conv1: Conv2d(64, 128, ...)
  layer3[0].conv1: Conv2d(128, 69, ...)
  layer4[0].conv1: Conv2d(256, 512, ...)


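In this run the allocation is no longer uniform: `layer3[0].conv1` dropped from 256 to 69 filters, while the other printed layers kept all of theirs. Because the model is randomly initialized (`weights=None`), these scores reflect initialization scale rather than learned importance. With trained weights, global pruning often spares early layers and prunes later layers with redundant features more aggressively. The exact allocation still depends on how the criterion scores each filter.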

## 3. Iterative Pruning

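For high compression ratios, iterative pruning usually works better than one-shot pruning. Removing filters in several smaller steps lets the model adapt between them. The loop below only shows how the ratio grows with each call to `prune_model()`. In a real run you would fine-tune between steps: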

In [None]:
model = resnet18(weights=None)
params_orig = sum(p.numel() for p in model.parameters())

# Iterative pruning: 5 steps to reach 50% of filters pruned
pruner = Pruner(
    model, 
    pruning_ratio=50,
    context='local', 
    criteria=large_final,
    iterative_steps=5  # Spread pruning over 5 steps
)

print('Iterative pruning (5 steps to reach 50%):')
for i in range(5):
    pruner.prune_model()
    params = sum(p.numel() for p in model.parameters())
    print(f'  Step {i+1}: {params:,} params ({100*(1-params/params_orig):.1f}% reduction)')

Ignoring output layer: fc
Total ignored layers: 1
Iterative pruning (5 steps to reach 50%):
  Step 1: 9,481,588 params (18.9% reduction)
  Step 2: 7,534,380 params (35.5% reduction)
  Step 3: 5,820,556 params (50.2% reduction)
  Step 4: 4,318,898 params (63.1% reduction)
  Step 5: 3,055,880 params (73.9% reduction)


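**In practice:** When using `PruneCallback` during training, iterative pruning happens automatically: the model is pruned a little after each batch, allowing it to recover between steps.

Note that `pruning_ratio` counts filters, not parameters. Each step above removes another 10% of the original filters, so step 5 reaches the 50% target, which corresponds to roughly 74% of the parameters. Step 3 (30% of filters) lands on exactly the same parameter count as the one-shot 30% run in Section 1.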
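The following sketch shows how the cumulative filter ratio could grow across steps. The linear ramp matches the parameter counts printed above. The cubic curve is the Automated Gradual Pruning (AGP) formula of Zhu & Gupta (2017), included because the workflow at the end of this page uses an `agp` schedule. fasterai's own `agp` may be parameterized differently, so treat both functions as illustrative.

```python
def linear_schedule(final_ratio, steps):
    """Cumulative filter-pruning ratio (%) after each step, increasing linearly."""
    return [final_ratio * (s + 1) / steps for s in range(steps)]

def agp_schedule(final_ratio, steps, initial_ratio=0.0):
    """Cubic AGP ramp (Zhu & Gupta, 2017): prunes fast early, slowly near the end."""
    return [final_ratio + (initial_ratio - final_ratio) * (1 - (s + 1) / steps) ** 3
            for s in range(steps)]

lin = linear_schedule(50, 5)  # [10.0, 20.0, 30.0, 40.0, 50.0]
agp = agp_schedule(50, 5)     # roughly [24.4, 39.2, 46.8, 49.6, 50.0]

for step, (l, a) in enumerate(zip(lin, agp), start=1):
    # A conv whose input and output channels both shrink keeps (1 - r)^2 of its weights
    removed = 1 - (1 - l / 100) ** 2
    print(f"step {step}: linear {l:.0f}% of filters (~{removed:.0%} of conv weights), AGP {a:.1f}%")
```

The quadratic estimate (19%, 36%, 51%, 64%, 75%) tracks the measured reductions above. Layers with a fixed input or output side shrink only linearly, which keeps the measured values slightly lower.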

## 4. Per-Layer Pruning Ratios

Layers differ in how sensitive they are to pruning. Instead of a single number, you can pass a dictionary that maps layer names to pruning ratios:

In [None]:
model = resnet18(weights=None)

# Conservative on early layers, aggressive on later layers
per_layer_ratios = {
    'layer1.0.conv1': 20,  'layer1.0.conv2': 20,  # 20% pruning
    'layer2.0.conv1': 40,  'layer2.0.conv2': 40,  # 40% pruning  
    'layer3.0.conv1': 60,  'layer3.0.conv2': 60,  # 60% pruning
    'layer4.0.conv1': 80,  'layer4.0.conv2': 80,  # 80% pruning
}

pruner = Pruner(model, per_layer_ratios, 'local', large_final)
pruner.prune_model()

print('\nPer-layer pruning results:')
print(f'  layer1.0.conv1: {model.layer1[0].conv1.out_channels} channels (20% pruned from 64)')
print(f'  layer2.0.conv1: {model.layer2[0].conv1.out_channels} channels (40% pruned from 128)')
print(f'  layer3.0.conv1: {model.layer3[0].conv1.out_channels} channels (60% pruned from 256)')
print(f'  layer4.0.conv1: {model.layer4[0].conv1.out_channels} channels (80% pruned from 512)')

Ignoring output layer: fc
Total ignored layers: 1
Using per-layer pruning with 8 layer-specific ratios

Per-layer pruning results:
  layer1.0.conv1: 51 channels (20% pruned from 64)
  layer2.0.conv1: 76 channels (40% pruned from 128)
  layer3.0.conv1: 102 channels (60% pruned from 256)
  layer4.0.conv1: 102 channels (80% pruned from 512)

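The printed counts are consistent with rounding the kept channels down, `kept = floor(channels × (1 − ratio / 100))`. That rule is inferred from the output above, not taken from fasterai's source. The sketch below reproduces the four numbers.

```python
import math

per_layer_ratios = {
    "layer1.0.conv1": 20, "layer2.0.conv1": 40,
    "layer3.0.conv1": 60, "layer4.0.conv1": 80,
}
original_channels = {
    "layer1.0.conv1": 64, "layer2.0.conv1": 128,
    "layer3.0.conv1": 256, "layer4.0.conv1": 512,
}

# Assumed rounding rule: keep floor(channels * (1 - ratio))
kept = {name: math.floor(original_channels[name] * (1 - ratio / 100))
        for name, ratio in per_layer_ratios.items()}

for name, n in kept.items():
    print(f"{name}: {n} of {original_channels[name]} channels kept")
```

Note that `layer3` and `layer4` both end at 102 channels. Pruning 80% of a wide layer can leave it no wider than a narrower layer pruned at 60%.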

**Tip:** Use sensitivity analysis to determine which layers can tolerate more pruning. See the [Sensitivity Tutorial](../analyze/sensitivity.html) for details.

## 5. Verifying the Pruned Model

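After pruning, the model is still a regular `nn.Module` and runs a forward pass; it just has fewer parameters. The output shape is unchanged because the output layer (`fc`) is ignored by the Pruner. Expect accuracy to drop until the pruned model is fine-tuned.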

In [None]:
model = resnet18(weights=None)
model.eval()

# Prune 50%
pruner = Pruner(model, 50, 'global', large_final)
pruner.prune_model()

# Verify forward pass works
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(x)

print('\nForward pass verification:')
print(f'  Input shape:  {x.shape}')
print(f'  Output shape: {output.shape}')
print('  Forward pass succeeded after pruning')

Ignoring output layer: fc
Total ignored layers: 1

Forward pass verification:
  Input shape:  torch.Size([1, 3, 224, 224])
  Output shape: torch.Size([1, 1000])
  Forward pass succeeded after pruning


## 6. Importance Criteria

The `criteria` parameter determines how filter importance is calculated:

| Criteria | Method | Best for |
|----------|--------|----------|
| `large_final` | Keep filters with largest L1 norm | General use, most common |
| `small_final` | Keep filters with smallest L1 norm | Unusual, for experimentation |
| `random` | Random selection | Baseline comparison |
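Here is a plain-Python sketch of how the three criteria rank filters, using the per-filter L1 norm described in the table. The filters and the `keep` helper are hypothetical, and the modes only mirror the spirit of `large_final`, `small_final` and `random`. They are not fasterai's implementations.

```python
import random as rnd  # aliased so it does not clash with fasterai's `random` criterion

# Made-up filters, each a flat list of weights
filters = {
    "f0": [0.9, -0.7, 0.4],
    "f1": [0.05, -0.02, 0.01],
    "f2": [-0.5, 0.6, 0.3],
    "f3": [0.1, 0.0, -0.2],
}

def l1(w):
    return sum(abs(x) for x in w)

def keep(filters, n_keep, mode, seed=0):
    """Return the names of the `n_keep` filters that survive under `mode`."""
    names = list(filters)
    if mode == "large":   # like large_final: keep the largest L1 norms
        return sorted(names, key=lambda n: l1(filters[n]), reverse=True)[:n_keep]
    if mode == "small":   # like small_final: keep the smallest L1 norms
        return sorted(names, key=lambda n: l1(filters[n]))[:n_keep]
    return rnd.Random(seed).sample(names, n_keep)  # like random: ignore the weights

for mode in ("large", "small", "random"):
    print(mode, keep(filters, 2, mode))
```

Every mode keeps the same number of filters, which is why the parameter counts in the next cell are identical. Only the choice of which filters survive changes.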

In [None]:
results = {}
for name, criteria in [('large_final', large_final), ('random', random)]:
    model = resnet18(weights=None)
    pruner = Pruner(model, 30, 'local', criteria)
    pruner.prune_model()
    results[name] = sum(p.numel() for p in model.parameters())

print('\nSame pruning ratio, different criteria:')
for name, params in results.items():
    print(f'  {name}: {params:,} parameters')
print('\nNote: parameter counts are identical; the criteria only change which filters are kept.')
print('With trained weights, large_final usually retains more accuracy than random.')

Ignoring output layer: fc
Total ignored layers: 1
Ignoring output layer: fc
Total ignored layers: 1

Same pruning ratio, different criteria:
  large_final: 5,820,556 parameters
  random: 5,820,556 parameters

Note: parameter counts are identical; the criteria only change which filters are kept.
With trained weights, large_final usually retains more accuracy than random.


## Summary

| Feature | Description |
|---------|-------------|
| **Structured pruning** | Removes entire filters, creating genuinely smaller models |
| **Local context** | Each layer pruned by the same percentage |
| **Global context** | Compare importance across all layers |
| **Iterative pruning** | Gradual pruning for better accuracy retention |
| **Per-layer ratios** | Dictionary of custom ratios per layer |
| **Auto dependency** | Handles layer connections automatically |

### Typical Workflow

```python
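# 1. One-shot pruning for quick experiments
pruner = Pruner(model, 30, 'local', large_final)
pruner.prune_model()

# 2. During training with PruneCallback (recommended)
#    Assumes `learn` is a fastai Learner and that PruneCallback and the
#    agp schedule have been imported from fasterai (see the PruneCallback tutorial)
cb = PruneCallback(pruning_ratio=50, schedule=agp, context='global', criteria=large_final)
learn.fit(10, cbs=[cb])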
```

---

## See Also

- [PruneCallback Tutorial](prune_callback.html) - Apply pruning during fastai training
- [Sparsifier Tutorial](../sparse/sparsifier.html) - Unstructured pruning alternative
- [Criteria](../../core/criteria.html) - Importance measures for filter selection
- [Schedules](../../core/schedules.html) - Control pruning progression