In [None]:
#| default_exp sparse.sparsifier

In [None]:
#| include: false
from nbdev.showdoc import *
from fastai.vision.all import *

In [None]:
#| export
from __future__ import annotations
import copy
import pickle
from typing import Callable, Type

import torch
import torch.nn as nn
from einops import rearrange
from fastcore.basics import store_attr, true

from fasterai.core.criteria import *

## Overview

A sparse vector, as opposed to a dense one, is a vector which contains a lot of zeroes. When we speak about making a neural network sparse, we thus mean that the network's weights are mostly zeroes.

With fasterai, you can do that thanks to the `Sparsifier` class.

In [None]:
#| export
class Sparsifier():
    "Class providing sparsifying capabilities"
    def __init__(self, 
                 model: nn.Module,                        # The model to sparsify
                 granularity: str,                        # Granularity of sparsification (e.g., 'weight', 'filter')
                 context: str,                            # Context for sparsification ('global' or 'local')
                 criteria: Criteria,                      # Criteria to determine which weights to keep
                 nm: bool = False,                        # Whether to use N:M sparsity pattern (forces 2:4 sparsity)
                 layer_type: Type[nn.Module] = nn.Conv2d  # Type of layers to apply sparsification to
    ):
        # In N:M mode the mask comes from the 2:4 pattern alone, so the sparsity
        # value passed to the sparsify methods does not influence it.
        if nm: print('Sparsity automatically set to 50% with 2:4 pattern')
        # Keep every constructor argument as an attribute of the same name
        # (the other methods read self.model, self.granularity, self.context, ...).
        store_attr()
        # Snapshot the current weights/biases into `_init_weights` / `_init_biases` buffers
        self._save_weights()
        # Start without a cached pruning threshold
        self._reset_threshold()

    def _iter_layers(self, 
                     filter_type: str = 'layer_type',       # Filter: 'layer_type' or 'has_weight'
                     model: nn.Module | None = None         # Model to iterate (default: self.model)
    ):
        "Iterate over model modules with filtering"
        model = model or self.model
        for m in model.modules():
            if filter_type == 'layer_type' and isinstance(m, self.layer_type):
                yield m
            elif filter_type == 'has_weight' and hasattr(m, 'weight'):
                yield m

    def _iter_named_layers(self):
        "Iterate over matching layers with their names"
        for name, m in self.model.named_modules():
            if isinstance(m, self.layer_type):
                yield name, m

    def _to_sparsity_dict(self, 
                          sparsity: float | dict  # Sparsity value or per-layer dict
    ) -> dict:
        "Convert any sparsity input to a {module: sparsity} dict"
        name_to_module = dict(self.model.named_modules())
        
        # Float: apply same sparsity to all layers
        if isinstance(sparsity, (int, float)):
            if not (0 <= sparsity <= 100):
                raise ValueError(f"sparsity must be in range [0, 100], got {sparsity}")
            return {m: sparsity for m in self._iter_layers()}
        
        # Dict: resolve names to modules
        if isinstance(sparsity, dict):
            resolved = {}
            for key, sp in sparsity.items():
                if not (0 <= sp <= 100):
                    raise ValueError(f"sparsity must be in range [0, 100], got {sp}")
                if isinstance(key, str):
                    if key in name_to_module:
                        resolved[name_to_module[key]] = sp
                    else:
                        print(f"Warning: Layer '{key}' not found in model, skipping")
                elif isinstance(key, nn.Module):
                    resolved[key] = sp
            return resolved
        
        raise TypeError(f"sparsity must be float or dict, got {type(sparsity)}")

    def sparsify_layer(self, 
                       m: nn.Module,              # The layer to sparsify
                       sparsity: float,           # Target sparsity level (percentage)
                       round_to: int | None = None  # Round to a multiple of this value
    ) -> None:
        "Apply sparsification to a single layer"
        if not (0 <= sparsity <= 100):
            raise ValueError(f"sparsity must be in range [0, 100], got {sparsity}")
        scores    = self._compute_scores(m, sparsity)
        threshold = self._compute_threshold(scores, sparsity, round_to)
        mask      = self._compute_mask(scores, threshold)
        m.register_buffer('_mask', mask)
        self._apply(m)
        self.criteria.update_weights(m)

    def sparsify_model(self, 
                       sparsity: float | dict,        # Target sparsity level or per-layer dict
                       round_to: int | None = None    # Round to a multiple of this value
    ) -> None:
        "Apply sparsification to all matching layers in the model"
        self._reset_threshold()
        
        # Validate context for non-uniform sparsity
        if isinstance(sparsity, dict) and self.context == 'global':
            raise ValueError("Dict-based sparsity requires 'local' context")
        
        # Convert to unified dict format
        sparsity_map = self._to_sparsity_dict(sparsity)
        
        # Single iteration loop for all cases
        mods = list(self.model.modules())
        for name, m in self._iter_named_layers():
            if m not in sparsity_map:
                continue
            sp = sparsity_map[m]
            self.sparsify_layer(m, sp, round_to)
            # Handle batch norm if present
            mod_idx = mods.index(m)
            if mod_idx + 1 < len(mods) and isinstance(mods[mod_idx + 1], nn.modules.batchnorm._BatchNorm):
                self.sparsify_batchnorm(m, mods[mod_idx + 1])
                
    def sparsify_batchnorm(self, 
                          m: nn.Module,       # The layer before batch norm
                          bn: nn.Module       # The batch norm layer
    ) -> None:
        "Apply filter pruning to batch norm parameters if appropriate"
        mask = getattr(m, "_mask", None)
        if self.granularity == 'filter' and true(mask):
            bn.weight.data.mul_(mask.squeeze())
            bn.bias.data.mul_(mask.squeeze())
            
    def _apply_masks(self) -> None:
        "Apply all stored masks to model weights"
        for m in self._iter_layers():
            self._apply(m)
        
    def _apply(self, 
              m: nn.Module  # Module to apply mask to
    ) -> None:
        "Apply mask to a module's weights"
        mask = getattr(m, "_mask", None)
        if true(mask): m.weight.data.mul_(mask)
        if self.granularity == 'filter' and true(m.bias):
            if true(mask): m.bias.data.mul_(mask.squeeze())
    
    def _reset_weights(self, 
                      model: nn.Module | None = None  # Model to reset (default: self.model)
    ) -> None:
        "Reset weights to their initial values"
        model = model or self.model
        for m in self._iter_layers('has_weight', model):
            init_weights = getattr(m, "_init_weights", m.weight)
            init_biases = getattr(m, "_init_biases", m.bias)
            with torch.no_grad():
                if true(m.weight): m.weight.copy_(init_weights)
                if true(m.bias): m.bias.copy_(init_biases)
            self._apply(m)
            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()
                
    def _save_weights(self) -> None:
        "Save initial weights of the model"
        for m in self._iter_layers('has_weight'):
            m.register_buffer("_init_weights", m.weight.clone())
            bias = getattr(m, 'bias', None)
            if true(bias): m.register_buffer("_init_biases", bias.clone())
                    
    def save_model(self, 
                  path: str,                            # Path to save the model
                  model: nn.Module | None = None        # Model to save (default: self.model)
    ) -> None:
        "Save a deep copy of the model to `path` with the sparsification buffers removed; the live model is untouched"
        # `model or self.model` would replace an explicitly passed empty container
        if model is None: model = self.model
        # Work on a copy so that neither the weights nor the buffers of the live model change
        tmp_model = copy.deepcopy(model)
        # NOTE(review): this restores the saved initial weights (then re-applies the masks)
        # on the copy, so the file holds masked *initial* weights rather than the current
        # ones - confirm this is the intended export.
        self._reset_weights(tmp_model)
        self._clean_buffers(tmp_model)
        # The whole module object is serialised, not just its state_dict
        torch.save(tmp_model, path)

    def _clean_buffers(self, 
                      model: nn.Module | None = None  # Model to clean (default: self.model)
    ) -> None:
        "Remove internal buffers used for sparsification"
        model = model or self.model
        for m in self._iter_layers('has_weight', model):
            if hasattr(m, '_mask'): del m._buffers["_mask"]
            if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
            if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]
                    
    def _reset_threshold(self) -> None:
        "Reset the threshold used for global pruning"
        self.threshold = None
            
    def _rounded_sparsity(self, 
                         n_to_prune: int,  # Number of elements to prune
                         round_to: int     # Rounding value
    ) -> int:
        "Round the number of elements to keep to a multiple of round_to"
        if round_to == 0:
            raise ValueError("round_to must be non-zero")
        return max(round_to * torch.ceil(n_to_prune / round_to), round_to)
    
    def _compute_scores(self, 
                       m: nn.Module,   # Module to compute scores for
                       sparsity: float # Target sparsity level
    ) -> torch.Tensor:
        "Compute importance scores for weights based on criteria"
        return self.criteria(m, self.granularity)
                
    def _compute_threshold(self, 
                          scores: torch.Tensor,  # Importance scores
                          sparsity: float,       # Target sparsity level
                          round_to: int | None   # Rounding value
    ) -> torch.Tensor:
        "Compute threshold for pruning, with optional rounding"
        if self.context == 'global':
            if self.threshold is None: 
                global_scores = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self._iter_layers()])
                self.threshold = torch.quantile(global_scores.view(-1), sparsity / 100)   
        elif self.context == 'local': 
            self.threshold = torch.quantile(scores.view(-1), sparsity / 100)
        else: 
            raise ValueError(f'Invalid context: {self.context}. Must be "global" or "local"')
            
        if round_to:
            n_to_keep = sum(scores.ge(self.threshold)).squeeze()
            self.threshold = torch.topk(scores.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
        return self.threshold
    
    def _compute_mask(self, 
                     scores: torch.Tensor,   # Importance scores
                     threshold: torch.Tensor # Threshold for pruning
    ) -> torch.Tensor:
        "Compute binary mask for weights based on scores and threshold"
        if self.nm: return self._apply_nm_sparsity(scores)
        if threshold > scores.max(): threshold = scores.max()
        return scores.ge(threshold).to(dtype=scores.dtype)

    def _apply_nm_sparsity(self, 
                          scores: torch.Tensor  # Importance scores
    ) -> torch.Tensor:
        "Apply 2:4 structured sparsity pattern (N:M sparsity where N=2, M=4)"
        out_channels, in_channels, kernel_height, kernel_width = scores.shape
    
        if in_channels % 4 != 0 or in_channels * kernel_height * kernel_width % 16 != 0:
            print(f"Skipping 2:4 sparsity, Cin * Kh * Kw is not a multiple of 16")
            return torch.ones_like(scores)
    
        blocked_scores = rearrange(scores, 'o (b c) h w -> h w o b c', c=4)
        threshold = blocked_scores.topk(k=2, dim=-1).values[..., -1:]
        mask = (blocked_scores >= threshold).float()
        return rearrange(mask, 'h w o b c -> o (b c) h w')

    def print_sparsity(self) -> None:
        "Print sparsity report for all layers"
        total_params = 0
        total_zeros = 0
        
        print("\nSparsity Report:")
        print("-" * 80)
        print(f"{'Layer':<30} {'Type':<15} {'Params':<10} {'Zeros':<10} {'Sparsity':<10}")
        print("-" * 80)
        
        for name, m in self._iter_named_layers():
            zeros = torch.sum(m.weight == 0).item()
            total = m.weight.nelement()
            sparsity_pct = 100.0 * zeros / total if total > 0 else 0
            
            print(f"{name:<30} {m.__class__.__name__:<15} "
                  f"{total:<10,d} {zeros:<10,d} {sparsity_pct:>8.2f}%")
            
            total_params += total
            total_zeros += zeros
        
        print("-" * 80)
        overall_sparsity = 100.0 * total_zeros / total_params if total_params > 0 else 0
        print(f"{'Overall':<30} {'all':<15} {total_params:<10,d} "
              f"{total_zeros:<10,d} {overall_sparsity:>8.2f}%")

In [None]:
# Render the generated API documentation for the class
show_doc(Sparsifier)

Found permutation search CUDA kernels
[ASP][Info] permutation_search_kernels can be imported.


---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L17){target="_blank" style="float:right; font-size:smaller"}

### Sparsifier

```python

def Sparsifier(
    model:nn.Module, # The model to sparsify
    granularity:str, # Granularity of sparsification (e.g., 'weight', 'filter')
    context:str, # Context for sparsification ('global' or 'local')
    criteria:Criteria, # Criteria to determine which weights to keep
    nm:bool=False, # Whether to use N:M sparsity pattern (forces 2:4 sparsity)
    layer_type:Type[nn.Module]=<class 'torch.nn.modules.conv.Conv2d'>, # Type of layers to apply sparsification to
):


```

*Class providing sparsifying capabilities*

The `Sparsifier` class allows us to remove weights that are considered less useful than others. This can be done by first creating an instance of the class, specifying:

- The `granularity`, i.e. the part of filters that you want to remove. Typically, we remove weights, vectors, kernels or even complete filters.
- The `context`, i.e. if you want to consider each layer independently (`local`), or compare the parameters to remove across the whole network (`global`).
- The `criteria`, i.e. the way to assess the usefulness of a parameter. Common methods compare parameters using their magnitude, the lowest magnitude ones considered to be less useful.

---

## Key Methods

You can sparsify a single layer with the `Sparsifier.sparsify_layer` method.

In [None]:
# Render the generated API documentation for the single-layer entry point
show_doc(Sparsifier.sparsify_layer)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L79){target="_blank" style="float:right; font-size:smaller"}

### Sparsifier.sparsify_layer

```python

def sparsify_layer(
    m:nn.Module, # The layer to sparsify
    sparsity:float, # Target sparsity level (percentage)
    round_to:int | None=None, # Round to a multiple of this value
)->None:


```

*Apply sparsification to a single layer*

---

Most of the time, you will want to sparsify the whole model at once with the `Sparsifier.sparsify_model` method, passing the percentage of sparsity you want to apply.

In [None]:
# Render the generated API documentation for the whole-model entry point
show_doc(Sparsifier.sparsify_model)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/sparse/sparsifier.py#L94){target="_blank" style="float:right; font-size:smaller"}

### Sparsifier.sparsify_model

```python

def sparsify_model(
    sparsity:float | dict, # Target sparsity level or per-layer dict
    round_to:int | None=None, # Round to a multiple of this value
)->None:


```

*Apply sparsification to all matching layers in the model*

---

## Advanced Options

In some cases, you may want the number of remaining parameters to be a multiple of a given number (e.g. 8); this can be done by passing the `round_to` parameter.

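Instead of passing a single sparsity value, you can provide a dictionary of per-layer sparsities, keyed by layer name. This allows fine-grained control over which layers get sparsified and by how much. Note that a per-layer dictionary requires `context='local'`; with `context='global'`, `sparsify_model` raises a `ValueError`.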

**Example**: Apply different sparsity levels to specific layers:

```python
sparsity_levels = {
    'conv1': 30,           # 30% sparsity on first conv
    'layer1.0.conv1': 50,  # 50% sparsity
    'layer2.0.conv1': 70,  # 70% sparsity (more aggressive)
}
sparsifier.sparsify_model(sparsity=sparsity_levels)
```

This works seamlessly with `SensitivityResult.to_layer_targets()` to apply sensitivity-aware compression.

In [None]:
#| hide
from fastcore.test import *

def _test_model():
    return nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(32, 10)
    )

# Sparsify single layer at 50%
conv = nn.Conv2d(3, 16, 3)
sp = Sparsifier(nn.Sequential(conv), 'weight', 'local', large_final, layer_type=nn.Conv2d)
sp.sparsify_layer(conv, 50)
# Share of exactly-zero weights, as a percentage; accepted within +/- 5 points of the target
actual = (conv.weight == 0).float().mean().item() * 100
test_close(actual, 50.0, eps=5.0)

# Buffers created
# (`_mask` by sparsify_layer, `_init_weights` by the constructor)
assert hasattr(conv, '_mask')
assert hasattr(conv, '_init_weights')

# Clean buffers
sp._clean_buffers()
assert not hasattr(conv, '_mask')

# sparsify_model with float
model = _test_model()
sp2 = Sparsifier(model, 'weight', 'local', large_final)
sp2.sparsify_model(30)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        assert (m.weight == 0).any()  # some sparsification applied

# Dict + context='global' raises ValueError
model2 = _test_model()
sp_g = Sparsifier(model2, 'weight', 'global', large_final)
with ExceptionExpected(ValueError):
    sp_g.sparsify_model({'0': 30, '3': 60})

# Invalid sparsity range
# (conv3 is not part of model3; the range check runs before the layer is touched)
model3 = _test_model()
sp3 = Sparsifier(model3, 'weight', 'local', large_final)
conv3 = nn.Conv2d(3, 16, 3)
with ExceptionExpected(ValueError): sp3.sparsify_layer(conv3, 150)
with ExceptionExpected(ValueError): sp3.sparsify_layer(conv3, -10)

# print_sparsity runs without error
# (only the Conv2d layers appear in the report, since that is the default layer_type)
model4 = _test_model()
sp4 = Sparsifier(model4, 'weight', 'local', large_final)
sp4.sparsify_model(50)
sp4.print_sparsity()


Sparsity Report:
--------------------------------------------------------------------------------
Layer                          Type            Params     Zeros      Sparsity  
--------------------------------------------------------------------------------
0                              Conv2d          432        216           50.00%
3                              Conv2d          4,608      2,304         50.00%
--------------------------------------------------------------------------------
Overall                        all             5,040      2,520         50.00%


In [None]:
#| hide
#| slow
from torchvision.models import resnet18
# Larger check on an untrained resnet18: sparsify each Conv2d layer locally at 60%
model_lg = resnet18(weights=None)
sp_lg = Sparsifier(model_lg, 'weight', 'local', large_final)
sp_lg.sparsify_model(60)
# Aggregate the zero count and parameter count over all Conv2d weights
total_zeros = sum((m.weight==0).sum().item() for m in model_lg.modules() if isinstance(m, nn.Conv2d))
total_params = sum(m.weight.numel() for m in model_lg.modules() if isinstance(m, nn.Conv2d))
# Overall conv sparsity must land within +/- 10 points of the 60% target
test_close(100*total_zeros/total_params, 60.0, eps=10.0)

---

## See Also

- [SparsifyCallback](sparsify_callback.html) - Apply sparsification during fastai training
- [Criteria](../core/criteria.html) - Different importance measures for selecting what to sparsify
- [Granularity](../core/granularity.html) - Control what gets sparsified (weights, filters, etc.)
- [Schedules](../core/schedules.html) - Control sparsification progression during training
- [Pruner](../prune/pruner.html) - Structured pruning that removes filters entirely