In [None]:
#| default_exp prune.pruner

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torch_pruning.pruner import function

import pickle
from itertools import cycle
from fastcore.basics import store_attr, listify, true
from fasterai.core.criteria import *
from fastai.vision.all import *


from torch_pruning.pruner.algorithms.scheduler import linear_scheduler
from torch.fx import symbolic_trace

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
class Pruner():
    "Structured pruning for neural networks using torch_pruning"
    def __init__(self,
                 model,                      # Model to prune; modified in place by `prune_model`
                 pruning_ratio,              # Fraction of channels to remove, in (0, 1]; values > 1 are read as percentages
                 context,                    # 'global' ranks channels network-wide; any other value prunes layer by layer
                 criteria,                   # fasterai criterion, called as `criteria(module, granularity, squeeze=True)`
                 schedule=linear_scheduler,  # Iterative pruning-ratio scheduler forwarded to MetaPruner
                 ignored_layers=None,        # Layers never pruned; auto-detected (output Linear + attention qkv) when falsy
                 example_inputs=None,        # Dummy input used to trace dependencies; defaults to a random (1, 3, 224, 224) tensor
                 *args, **kwargs             # Forwarded to `tp.pruner.MetaPruner`
                ):
        store_attr()
        self.num_heads = {}
        if self.example_inputs is None:
            # Built per instance rather than as a def-time default: no tensor is allocated at import
            # time and instances never share the same input object.
            self.example_inputs = torch.randn(1, 3, 224, 224)
        # Values above 1 are interpreted as percentages (e.g. 30 -> 0.3)
        if self.pruning_ratio > 1: self.pruning_ratio = self.pruning_ratio / 100
        if not (0 < self.pruning_ratio <= 1):
            raise ValueError(f"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}")
        # Validate first (above) so a bad ratio fails before the graph is traced
        if not self.ignored_layers: self.get_ignored_layers(self.model)
        device = next(self.model.parameters()).device
        self.pruner = tp.pruner.MetaPruner(
            self.model,
            example_inputs=self.example_inputs.to(device),
            importance=self.group_importance,
            pruning_ratio=self.pruning_ratio,
            ignored_layers=self.ignored_layers,
            global_pruning=(self.context == 'global'),
            num_heads=self.num_heads,
            iterative_pruning_ratio_scheduler=self.schedule,
            *args,
            **kwargs
        )

    def prune_model(self):
        "Execute one pruning step and restore attention layer configurations"
        self.pruner.step()
        self.restore_attention_layers()


    def get_linear_layers_to_ignore(self,
                                    model: nn.Module  # The model to analyze
    ):
        "Find and ignore output Linear layers to preserve model output dimensions"
        modules = dict(model.named_modules())  # built once instead of once per graph node
        traced = symbolic_trace(model)
        for node in traced.graph.nodes:
            if node.op != "output": continue
            for input_node in node.all_input_nodes:
                # Only `call_module` nodes carry a submodule path as their target; other node kinds
                # (placeholders, call_method, call_function) must not be matched against module names.
                if input_node.op != "call_module": continue
                module = modules.get(input_node.target)
                if isinstance(module, nn.Linear):
                    self.ignored_layers.append(module)
                    print(f"Ignoring output layer: {module}")


    @staticmethod
    def _qkv_projection(module):
        "Return the fused qkv projection (`qkv`, else `qkv_proj`) of a module exposing `num_heads`, or None"
        if not hasattr(module, 'num_heads'): return None
        for attr in ('qkv', 'qkv_proj'):
            if hasattr(module, attr): return getattr(module, attr)
        return None


    def get_attention_layers_to_ignore(self,
                                       model: nn.Module  # The model to analyze
    ):
        "Find and ignore attention layers (qkv projections) to preserve attention structure"
        for module in model.modules():
            qkv = self._qkv_projection(module)
            if qkv is None: continue
            self.ignored_layers.append(qkv)
            # Remember the head count so it can be restored after each pruning step
            self.num_heads[qkv] = module.num_heads
            print(f"Attention layer ignored: {qkv}, num_heads={module.num_heads}")


    def get_ignored_layers(self,
                           model: nn.Module  # The model to analyze
    ):
        "Build list of layers to ignore during pruning"
        self.ignored_layers = []
        self.get_linear_layers_to_ignore(model)
        self.get_attention_layers_to_ignore(model)
        print(f"Total ignored layers: {len(self.ignored_layers)}")


    def restore_attention_layers(self):
        "Restore num_heads and head_dim attributes after pruning attention layers"
        for m in self.model.modules():
            qkv = self._qkv_projection(m)
            # Skip projections whose head count was never recorded (e.g. when the caller supplied
            # `ignored_layers`, so attention detection did not run); indexing would raise KeyError.
            if qkv is None or qkv not in self.num_heads: continue
            m.num_heads = self.num_heads[qkv]
            # The fused projection stacks q, k and v, hence the factor 3
            m.head_dim = qkv.out_features // (3 * m.num_heads)


    def group_importance(self, group):
        "Compute importance scores for a dependency group (None when no member can be scored)"
        # Map torch_pruning handlers to the granularity names understood by fasterai criteria
        handler_map = {
            function.prune_conv_out_channels: 'filter',
            function.prune_linear_out_channels: 'row',
            function.prune_linear_in_channels: 'column',
            function.prune_conv_in_channels: 'shared_kernel',
        }

        group_imp, group_idxs = [], []
        for i, (dep, _) in enumerate(group):
            granularity = handler_map.get(dep.handler)
            if granularity is None: continue
            group_imp.append(self.criteria(dep.target.module, granularity, squeeze=True))
            group_idxs.append(group[i].root_idxs)

        # Nothing scorable in this group: None tells MetaPruner to skip it
        if not group_imp: return None

        # Accumulate on the device of the scores (the model's device) rather than round-tripping
        # through CPU and returning on fastai's global default device.
        reduced_imp = torch.zeros_like(group_imp[0])
        for imp, root_idxs in zip(group_imp, group_idxs):
            imp = imp.to(reduced_imp.device)
            reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=reduced_imp.device), imp)

        # Average over the number of scored members of the group
        reduced_imp /= len(group_imp)
        return reduced_imp

In [None]:
# Render the API documentation of the single-step pruning entry point
show_doc(Pruner.prune_model)

Let's try the `Pruner` with a ResNet-18 model

In [None]:
model = resnet18()
model  # unpruned architecture, for comparison with the pruned versions below

The `Pruner` can either remove filters based on a `local` criterion (i.e. each layer is trimmed by the same percentage of its filters)

In [None]:
# Start from a fresh model so re-running this cell never prunes an already-pruned network.
# 30 is read as 30%: every prunable layer loses the same share of its filters.
model = resnet18()
pruner = Pruner(model, 30, 'local', large_final)
pruner.prune_model()
model

The `Pruner` can also remove filters based on a `global` criterion (i.e. each layer may be trimmed by a different percentage of its filters, but we specify the sparsity of the network as a whole)

In [None]:
# 50 is read as 50%: channels are ranked across the whole network, so per-layer ratios differ.
model = resnet18()
pruner = Pruner(model, 50, 'global', large_final)
pruner.prune_model()
model