In [None]:
#| default_exp prune.pruner

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torch_pruning.pruner import function

import pickle
from itertools import cycle
from fastcore.basics import store_attr, listify, true
from fasterai.core.criteria import *
from fastai.vision.all import *


from torch_pruning.pruner.algorithms.scheduler import linear_scheduler
from torch.fx import symbolic_trace

In [None]:
#| include: false
from nbdev.showdoc import *

## Overview

The `Pruner` class provides structured pruning capabilities using the [torch-pruning](https://github.com/VainF/Torch-Pruning) library. Unlike unstructured pruning (which zeros individual weights), structured pruning removes entire filters/channels, resulting in a genuinely smaller and faster model.

**Key Features:**
- Automatic dependency handling across layers
- Support for both local (per-layer) and global (cross-layer) pruning
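- Automatic detection of attention layers (modules exposing `num_heads` with a `qkv` or `qkv_proj` projection), which are left unpruned so their head structure stays intact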
- Compatible with various importance criteria from `fasterai.core.criteria`

### Sparsifier vs Pruner: When to Use Which?

| Aspect | Sparsifier | Pruner |
|--------|------------|--------|
| **What it removes** | Individual weights (unstructured) | Entire filters/channels (structured) |
| **Model size** | Same architecture, sparse weights | Smaller architecture |
| **Speedup** | Requires sparse hardware/libraries | Immediate speedup on any hardware |
| **Accuracy impact** | Usually a smaller accuracy drop at the same sparsity | Usually a larger drop; fine-tuning is often needed |
| **Best for** | Research, sparse-aware inference | Production deployment |

In [None]:
#| export
from fasterai.core.schedule import Schedule

class Pruner():
    "Structured pruning for neural networks using torch_pruning"
    def __init__(self,
                 model,                 # Model to prune; its layers are modified in place
                 pruning_ratio,         # Float in (0, 1] (values > 1 are read as percentages), or dict {layer name | module: ratio}
                 context,               # 'global' ranks channels across layers; any other value prunes each layer locally
                 criteria,              # Importance criterion, called as `criteria(module, granularity, squeeze=True)`
                 schedule=linear_scheduler, # fasterai `Schedule` or a torch-pruning compatible scheduler callable
                 ignored_layers=None,   # Layers to leave unpruned; auto-detected when empty or None
                 example_inputs=torch.randn(1, 3, 224, 224), # Dummy input handed to torch-pruning (NOTE: created once, at class definition)
                 *args,                 # Forwarded positionally to `tp.pruner.MetaPruner`
                 **kwargs               # Forwarded to `tp.pruner.MetaPruner` (`default_pruning_ratio` is consumed in the dict case)
                ):
        store_attr()
        # qkv projection module -> head count recorded at detection time (see get_attention_layers_to_ignore)
        self.num_heads = {}
        # Baseline for the reduction figure in print_sparsity; counts every parameter of the model
        self._original_params = sum(p.numel() for p in model.parameters())
        if not self.ignored_layers: self.get_ignored_layers(self.model)

        # Handle pruning_ratio as float or dict
        self.pruning_ratio_dict = None
        if isinstance(self.pruning_ratio, dict):
            # Convert name-based dict to module-based dict for torch-pruning
            self.pruning_ratio_dict = self._resolve_pruning_ratio_dict(self.pruning_ratio)
            # Ratio for layers not listed in the dict; normalised like every other ratio
            self.default_pruning_ratio = self._normalize_ratio(kwargs.pop('default_pruning_ratio', 0.0))
            print(f"Using per-layer pruning with {len(self.pruning_ratio_dict)} layer-specific ratios")
        else:
            self.pruning_ratio = self._normalize_ratio(self.pruning_ratio)
            if not (0 < self.pruning_ratio <= 1):
                raise ValueError(f"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}")
            self.default_pruning_ratio = self.pruning_ratio

        # Convert Schedule object to torch-pruning compatible function
        tp_schedule = self._to_tp_scheduler(self.schedule)

        self.pruner = tp.pruner.MetaPruner(
            self.model,
            # Move the dummy input to wherever the model's parameters live
            example_inputs=self.example_inputs.to(next(self.model.parameters()).device),
            importance=self.group_importance,
            pruning_ratio=self.default_pruning_ratio,
            pruning_ratio_dict=self.pruning_ratio_dict,
            ignored_layers=self.ignored_layers,
            global_pruning=self.context == 'global',
            num_heads=self.num_heads,
            iterative_pruning_ratio_scheduler=tp_schedule,
            *args,
            **kwargs
        )

    @staticmethod
    def _normalize_ratio(ratio):
        "Interpret ratios greater than 1 as percentages (e.g. 30 -> 0.3); others are returned unchanged"
        return ratio / 100 if ratio > 1 else ratio

    @staticmethod
    def _attention_qkv(module):
        "Return the fused `qkv` (or `qkv_proj`) projection of a module that has `num_heads`, else None"
        if not hasattr(module, 'num_heads'): return None
        for attr in ('qkv', 'qkv_proj'):
            if hasattr(module, attr): return getattr(module, attr)
        return None

    def _build_pruning_schedule(self, sched_func):
        "Create a schedule function compatible with torch-pruning's Pruner"
        def scheduler(pruning_ratio, steps, start=0, end=1):
            # One cumulative ratio per step 0..steps: sched_func(start, end, pos) is evaluated
            # at pos = i/steps and scaled by the target pruning_ratio
            return [
                sched_func(start, end, i / float(steps)) * pruning_ratio
                for i in range(steps + 1)
            ]
        return scheduler

    def _to_tp_scheduler(self, schedule):
        "Convert Schedule object or callable to torch-pruning compatible scheduler"
        # If it's a Schedule object, extract sched_func and build compatible function
        if isinstance(schedule, Schedule):
            return self._build_pruning_schedule(schedule.sched_func)
        # Otherwise assume it's already a compatible callable (like linear_scheduler)
        return schedule

    def _resolve_pruning_ratio_dict(self, ratio_dict):
        "Convert layer name strings to module references for torch-pruning"
        name_to_module = dict(self.model.named_modules())
        resolved = {}
        for key, ratio in ratio_dict.items():
            if isinstance(key, str):
                if key not in name_to_module:
                    print(f"Warning: Layer '{key}' not found in model, skipping")
                    continue
                key = name_to_module[key]
            elif not isinstance(key, nn.Module):
                # Only layer names and module objects can be mapped onto torch-pruning's dict
                print(f"Warning: unsupported pruning_ratio key {key!r}, skipping")
                continue
            resolved[key] = self._normalize_ratio(ratio)
        return resolved

    def prune_model(self):
        "Execute one pruning step and restore attention layer configurations"
        self.pruner.step()
        self.restore_attention_layers()

    def get_linear_layers_to_ignore(self, 
                                    model: nn.Module  # The model to analyze
    ):
        "Find and ignore output Linear layers to preserve model output dimensions"
        try:
            traced = symbolic_trace(model)
        except Exception as e:
            # Best effort: models that torch.fx cannot trace simply skip this detection
            print(f"Could not trace model for output layer detection: {e}")
            return
        modules = dict(model.named_modules())
        for node in traced.graph.nodes:
            if node.op != "output": continue
            for input_node in node.all_input_nodes:
                # Only call_module nodes carry a target that names a submodule
                if input_node.op != "call_module": continue
                module = modules.get(input_node.target)
                if isinstance(module, nn.Linear) and module not in self.ignored_layers:
                    self.ignored_layers.append(module)
                    print(f"Ignoring output layer: {input_node.target}")

    def get_attention_layers_to_ignore(self, 
                                       model: nn.Module  # The model to analyze
    ):
        "Find and ignore attention layers (qkv projections) to preserve attention structure"
        for module in model.modules():
            qkv = self._attention_qkv(module)
            if qkv is None: continue
            self.ignored_layers.append(qkv)
            # Remembered so restore_attention_layers can reset head metadata after each step
            self.num_heads[qkv] = module.num_heads
            print(f"Attention layer ignored: {qkv}, num_heads={module.num_heads}")

    def get_ignored_layers(self, 
                           model: nn.Module  # The model to analyze
    ):
        "Build list of layers to ignore during pruning"
        # Auto-detection starts from an empty list, replacing whatever __init__ stored
        self.ignored_layers = []
        self.get_linear_layers_to_ignore(model)
        self.get_attention_layers_to_ignore(model)
        print(f"Total ignored layers: {len(self.ignored_layers)}")

    def restore_attention_layers(self):
        "Restore num_heads and head_dim attributes after pruning attention layers"
        for m in self.model.modules():
            qkv = self._attention_qkv(m)
            # Skip modules whose head count was never recorded (e.g. the caller supplied
            # ignored_layers, so attention detection did not run) instead of raising KeyError
            if qkv is None or qkv not in self.num_heads: continue
            m.num_heads = self.num_heads[qkv]
            # Fused projection holds q, k and v side by side, hence the factor 3
            m.head_dim = qkv.out_features // (3 * m.num_heads)

    def group_importance(self, group):
        "Compute importance scores for a dependency group"
        # torch-pruning handler -> granularity name understood by fasterai criteria
        handler_map = {
            function.prune_conv_out_channels: 'filter',
            function.prune_linear_out_channels: 'row',
            function.prune_linear_in_channels: 'column',
            function.prune_conv_in_channels: 'shared_kernel',
        }

        group_imp, group_idxs = [], []
        for i, (dep, idxs) in enumerate(group):
            granularity = handler_map.get(dep.handler)
            if granularity is None: continue
            # Scores are accumulated on CPU; assumes criteria returns one score per channel
            group_imp.append(self.criteria(dep.target.module, granularity, squeeze=True).to('cpu'))
            group_idxs.append(group[i].root_idxs)

        if not group_imp:
            return torch.tensor([])

        # Scatter each layer's scores onto the root layer's channel positions, then average
        reduced_imp = torch.zeros_like(group_imp[0])
        for imp, root_idxs in zip(group_imp, group_idxs):
            # Explicit long dtype keeps an empty index list from becoming a float tensor
            reduced_imp.scatter_add_(0, torch.tensor(root_idxs, dtype=torch.long), imp)
        reduced_imp /= len(group_imp)

        return reduced_imp.to(default_device())

    def print_sparsity(self) -> None:
        "Print pruning report showing channel counts and parameter reduction"
        sep = "-" * 85
        print("\nPruning Report:")
        print(sep)
        print(f"{'Layer':<35} {'Type':<12} {'In Ch':<8} {'Out Ch':<8} {'Params':<12}")
        print(sep)

        # One row per Conv2d / Linear layer
        for name, m in self.model.named_modules():
            if isinstance(m, nn.Conv2d):
                layer_type, in_ch, out_ch = "Conv2d", m.in_channels, m.out_channels
            elif isinstance(m, nn.Linear):
                layer_type, in_ch, out_ch = "Linear", m.in_features, m.out_features
            else:
                continue
            params = sum(p.numel() for p in m.parameters())
            print(f"{name:<35} {layer_type:<12} {in_ch:<8} {out_ch:<8} {params:<12,}")

        # Count every parameter (BatchNorm etc. included) so the total is measured the same
        # way as _original_params; a Conv/Linear-only subtotal would overstate the reduction
        total_params = sum(p.numel() for p in self.model.parameters())
        reduction = 100 * (1 - total_params / self._original_params) if self._original_params > 0 else 0
        print(sep)
        print(f"{'Total':<35} {'':<12} {'':<8} {'':<8} {total_params:<12,}")
        print(f"{'Original':<35} {'':<12} {'':<8} {'':<8} {self._original_params:<12,}")
        print(f"{'Reduction':<35} {'':<12} {'':<8} {'':<8} {reduction:>10.2f}%")

In [None]:
show_doc(Pruner.prune_model)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/prune/pruner.py#L97){target="_blank" style="float:right; font-size:smaller"}

### Pruner.prune_model

```python

def prune_model(
    
):


```

*Execute one pruning step and restore attention layer configurations*

In [None]:
show_doc(Pruner.group_importance)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/prune/pruner.py#L159){target="_blank" style="float:right; font-size:smaller"}

### Pruner.group_importance

```python

def group_importance(
    group
):


```

*Compute importance scores for a dependency group*

In [None]:
show_doc(Pruner.print_sparsity)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/prune/pruner.py#L191){target="_blank" style="float:right; font-size:smaller"}

### Pruner.print_sparsity

```python

def print_sparsity(
    
)->None:


```

*Print pruning report showing channel counts and parameter reduction*

---

## Usage Examples

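Let's try the `Pruner` on a ResNet-18 model, pruning 30% of the channels (`30` is read as a percentage) with per-layer (`'local'`) ranking by the `large_final` criterion: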

```python
model = resnet18()
pruner = Pruner(model, 30, 'local', large_final)
pruner.prune_model()
```

In [None]:
#| hide
from fastcore.test import *

def _test_model():
    "Small conv -> linear classifier used to exercise structured pruning"
    return nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(32, 10)
    )

# Fixed seed so weights, inputs and therefore the pruned channels are reproducible on re-run
torch.manual_seed(0)

# Pruner construction: the final Linear (index 8) is auto-detected as an output layer and kept
model = _test_model()
x = torch.randn(1, 3, 8, 8)
pruner = Pruner(model, 30, 'local', large_final, example_inputs=x)
assert model[8] in pruner.ignored_layers

# Prune model reduces parameter count
params_before = sum(p.numel() for p in model.parameters())
pruner.prune_model()
params_after = sum(p.numel() for p in model.parameters())
assert params_after < params_before

# Dependent layers were pruned consistently
test_eq(model[1].num_features, model[0].out_channels)
test_eq(model[3].in_channels, model[0].out_channels)
test_eq(model[4].num_features, model[3].out_channels)
test_eq(model[8].in_features, model[3].out_channels)

# Model still produces valid output after pruning
out = model(x)
test_eq(out.shape[0], 1)  # batch dim preserved
test_eq(out.shape[1], 10)  # output classes preserved

Ignoring output layer: 8
Total ignored layers: 1


---

## See Also

- [PruneCallback](prune_callback.html) - Apply structured pruning during fastai training
- [Criteria](../core/criteria.html) - Different importance measures for selecting what to prune
- [Schedules](../core/schedules.html) - Control pruning progression during training
- [Sparsifier](../sparse/sparsifier.html) - Unstructured pruning (zeroing weights without removing them)
- [torch-pruning documentation](https://github.com/VainF/Torch-Pruning) - The underlying library used by Pruner