# Customize Your Own Pruner

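Torch-Pruning is an extensible toolkit that lets you build your own pruners with custom importance criteria and pruning schemes. For example, you can use Torch-Pruning to implement the [Slimming pruner](https://arxiv.org/abs/1708.06519). This pruner uses the scaling parameters of batch normalization (BN) layers to identify and remove unimportant channels.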
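For reference, here is the cited paper's formulation, with notation lightly adapted.

- Each BN layer normalizes its input over a mini-batch $\mathcal{B}$ and then applies a per-channel scale $\gamma$ and shift $\beta$.
- Training adds an L1 penalty over the set $\Gamma$ of all scaling factors, weighted by $\lambda$ (the `reg` argument used later in this notebook).
- The last line is the subgradient that sparse training adds to each $\gamma$'s gradient.

$$
\hat{z} = \frac{z_{\text{in}} - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon}}, \qquad z_{\text{out}} = \gamma\,\hat{z} + \beta
$$

$$
L = \sum_{(x,\,y)} l\big(f(x, W),\, y\big) + \lambda \sum_{\gamma \in \Gamma} |\gamma|
$$

$$
\frac{\partial}{\partial \gamma}\,\lambda|\gamma| = \lambda\,\operatorname{sign}(\gamma) \quad (\gamma \neq 0)
$$

A channel whose $|\gamma|$ is driven close to zero contributes little to $z_{\text{out}}$. This is why $|\gamma|$ can serve as its importance score.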



```python
import warnings
warnings.filterwarnings('ignore')
import sys, os
sys.path.append(os.path.abspath("../"))

import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch_pruning as tp
```

### 1. Pruner Definition

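The Slimming pruner uses the scaling factors of batch normalization (BN) layers to compute an importance score for each channel. The technique follows a "train, prune, fine-tune" paradigm, which begins with sparse training of the original model. In Torch-Pruning, the base class `tp.pruner.MetaPruner` provides a `.regularize(model)` interface for sparse training.

Our first task is to implement this interface so that it regularizes the BN parameters. For every affine BN layer, it adds the subgradient `reg * sign(weight)` of an L1 (Lasso) penalty to the gradient of the layer's scaling factors.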
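To make the effect of this penalty concrete, here is a framework-free sketch of plain SGD on a few BN scaling factors, with the L1 subgradient added to their gradients. It assumes a deliberately simplified, hypothetical setting: the task loss contributes no gradient to these channels, so only the penalty acts. `sign` and `sparse_sgd_step` are hypothetical helpers and are not part of Torch-Pruning.

```python
def sign(x):
    """Return -1, 0 or +1 for a scalar, like torch.sign."""
    return (x > 0) - (x < 0)


def sparse_sgd_step(gammas, data_grads, reg, lr):
    """One SGD step on BN scaling factors with an L1 penalty reg * sum(|gamma|).

    The penalty contributes the subgradient reg * sign(gamma). It is added to the
    task-loss gradient before the parameter update.
    """
    return [g - lr * (dg + reg * sign(g)) for g, dg in zip(gammas, data_grads)]


gammas = [1.0, 0.05, -0.03]
for _ in range(10):
    # Hypothetical: the task loss ignores these channels, so the data gradient is zero
    gammas = sparse_sgd_step(gammas, [0.0, 0.0, 0.0], reg=0.05, lr=0.1)

# Each factor moves toward zero by lr * reg = 0.005 per step.
# The small ones end up within about 0.005 of zero; the large one only drops to about 0.95.
print(gammas)
```

With a plain subgradient, a factor that reaches zero tends to oscillate around it instead of staying exactly at zero. Such factors are small but not removed, and removing them is the job of the pruning step.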



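```python
class MySlimmingPruner(tp.pruner.MetaPruner):
    def regularize(self, model, reg):
        """Add the L1 subgradient reg * sign(gamma) to BN weight grads; call after loss.backward()."""
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and m.affine:
                if m.weight.grad is None:  # no gradient yet, e.g. called before backward()
                    continue
                m.weight.grad.data.add_(reg * torch.sign(m.weight.data))  # Lasso for sparsity
```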


### 2. Importance function
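Next, we need a new importance criterion for pruning that compares the magnitudes of the BN scaling parameters. In Torch-Pruning, an importance criterion is a callable function or object that takes a `tp.PruningGroup` as input. A `tp.PruningGroup` records all coupled layers together with their pruning indices. We can iterate over the group to design our own importance function, as follows: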



```python
class MySlimmingImportance(tp.importance.Importance):
    def __call__(self, group, **kwargs):
        # Note that a group may contain multiple BN layers:
        # we store layer-wise scores in a list and then reduce them to get the final result.
        group_imp = [] # (num_bns, num_channels)
        # 1. Iterate over the group to estimate importance
        for dep, idxs in group:
            layer = dep.target.module # the target layer of this dependency
            prune_fn = dep.handler    # the pruning function for the target layer (unused here)
            if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and layer.affine:
                local_imp = torch.abs(layer.weight.data)
                group_imp.append(local_imp)
        if len(group_imp)==0: return None # return None if the group contains no BN layer
        # 2. Reduce the group importance to a 1-D score vector; here we average the scores across layers.
        group_imp = torch.stack(group_imp, dim=0).mean(dim=0)
        return group_imp # (num_channels, )

# You can implement any importance function, as long as it maps a group to a 1-D score vector.

# The following class is an example of an importance function that returns random scores.
class RandomImportance(tp.importance.Importance):
    @torch.no_grad()
    def __call__(self, group, **kwargs):
        _, idxs = group[0]
        return torch.rand(len(idxs)) # a random score vector with the same length as idxs
```
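An importance function only has to return one score per channel in the group. The framework-free sketch below illustrates roughly what happens to that vector:

1. It averages hypothetical |γ| values of two coupled BN layers, mirroring `MySlimmingImportance`.
2. It picks the lowest-scoring channels as the ones to remove.

`group_scores` and `lowest_score_indices` are hypothetical helpers. Torch-Pruning's own selection logic (for example, how many channels are removed per step, or global versus per-layer ranking) may differ.

```python
def group_scores(bn_weights):
    """Average |gamma| across coupled BN layers, giving one score per channel."""
    num_layers = len(bn_weights)
    return [sum(abs(w) for w in channel) / num_layers for channel in zip(*bn_weights)]


def lowest_score_indices(scores, n_prune):
    """Indices of the n_prune smallest scores (least important channels first)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    return order[:n_prune]


# Hypothetical BN scaling factors of two coupled layers with 4 channels each
bn_a = [0.9, -0.02, 0.5, 0.01]
bn_b = [0.7, 0.06, -0.3, 0.03]

scores = group_scores([bn_a, bn_b])
print(lowest_score_indices(scores, n_prune=2))  # → [3, 1]
```

Because the scores are averaged over all coupled layers, a channel is considered unimportant only if its scaling factors are small across the whole group.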


### 3. Pruning
Now let's use the customized pruner to slim down a ResNet-18 model.

```python
model = resnet18(pretrained=True)  # on torchvision >= 0.13, prefer resnet18(weights="IMAGENET1K_V1")
example_inputs = torch.randn(1, 3, 224, 224)
```

```python
# 0. importance criterion
imp = MySlimmingImportance()

# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

# 2. Pruner initialization
iterative_steps = 5 # You can prune your model to the target sparsity iteratively.
pruner = MySlimmingPruner(
    model, 
    example_inputs, 
    global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
    importance=imp, # importance criterion for parameter selection
    iterative_steps=iterative_steps, # the number of iterations to achieve target sparsity
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)
```
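With `iterative_steps=5` and `ch_sparsity=0.5`, the target sparsity is reached gradually. The sketch below rests on two assumptions made for illustration:

- The schedule is linear: step k of n targets a pruning ratio of `ch_sparsity * k / n`.
- Each layer keeps `int(C * (1 - ratio))` of its original `C` channels.

`linear_schedule` and `kept_channels` are hypothetical helpers, not Torch-Pruning's scheduler. The sketch shows the ResNet-18 stage widths shrinking step by step until they reach the `{32, 64, 128, 256}` given in the `ch_sparsity` comment.

```python
def linear_schedule(target_ratio, steps):
    """Pruning ratio targeted after each step, growing linearly up to target_ratio."""
    return [target_ratio * (k + 1) / steps for k in range(steps)]


def kept_channels(channels, ratio):
    """Channels left after removing `ratio` of the original count (rounded down)."""
    return int(channels * (1 - ratio))


widths = [64, 128, 256, 512]  # ResNet-18 stage widths
for step, ratio in enumerate(linear_schedule(0.5, 5), start=1):
    print(f"step {step}: ratio={ratio:.1f}, widths={[kept_channels(c, ratio) for c in widths]}")
```

Pruning in several small steps, with fine-tuning in between, gives the network a chance to recover before the next batch of channels is removed.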

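Use `pruner.regularize` for sparse training. It adds to the existing gradients of the BN scaling factors, so call it after `loss.backward()` and before `optimizer.step()`.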

```python
# Sparse training (fill in your own data, loss and optimizer)
for _ in range(100):
    pass
    # optimizer.zero_grad()
    # ...
    # loss.backward()
    # pruner.regularize(model, reg=1e-5)  # after backward(), before step()
    # optimizer.step()
```

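Pruning and fine-tuning: prune the model over `iterative_steps` steps. After each step, print the output shape, parameter count, and MACs, then fine-tune the model before the next step.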

```python
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()

    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # print(model)
    print(model(example_inputs).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )
    print("="*16)
    # finetune your model here
    # finetune(model)
    # ...
```

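The recorded output below appears to come from a run in which the model had already been pruned, for example after re-executing the cell. That is why the baseline and pruned counts are identical. An unpruned ResNet-18 has about 11.7 M parameters and 1.8 GMACs, whereas 3.06 M parameters corresponds to the half-width network.

```
torch.Size([1, 1000])
  Iter 1/5, Params: 3.06 M => 3.06 M
  Iter 1/5, MACs: 0.49 G => 0.49 G
torch.Size([1, 1000])
  Iter 2/5, Params: 3.06 M => 3.06 M
  Iter 2/5, MACs: 0.49 G => 0.49 G
torch.Size([1, 1000])
  Iter 3/5, Params: 3.06 M => 3.06 M
  Iter 3/5, MACs: 0.49 G => 0.49 G
torch.Size([1, 1000])
  Iter 4/5, Params: 3.06 M => 3.06 M
  Iter 4/5, MACs: 0.49 G => 0.49 G
torch.Size([1, 1000])
  Iter 5/5, Params: 3.06 M => 3.06 M
  Iter 5/5, MACs: 0.49 G => 0.49 G
```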
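The parameter counts in this output can be checked by hand. The framework-free sketch below counts the parameters of a torchvision-style ResNet-18 for given stage widths: bias-free convolutions, affine BN layers, and a final `Linear` with bias. The helper names are hypothetical. It assumes every stage is cut to half width while the 3 input channels and 1000 output classes are kept, as the `ch_sparsity` comment describes. Under that assumption it gives 3,055,880 parameters (about 3.06 M), versus 11,689,512 for the original network.

```python
def conv_params(c_in, c_out, k):
    """Weights of a k x k Conv2d without bias (torchvision ResNets use bias=False)."""
    return c_in * c_out * k * k


def bn_params(c):
    """Learnable parameters of an affine BatchNorm: gamma and beta (running stats are buffers)."""
    return 2 * c


def basic_block_params(c_in, c_out):
    """Two 3x3 conv + BN pairs, plus a 1x1 conv + BN shortcut when the width changes."""
    n = conv_params(c_in, c_out, 3) + bn_params(c_out)
    n += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if c_in != c_out:  # in ResNet-18 this coincides with the stride-2 downsampling blocks
        n += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return n


def resnet18_params(widths=(64, 128, 256, 512), in_channels=3, num_classes=1000):
    total = conv_params(in_channels, widths[0], 7) + bn_params(widths[0])  # stem
    c_in = widths[0]
    for c_out in widths:
        for _ in range(2):  # two BasicBlocks per stage
            total += basic_block_params(c_in, c_out)
            c_in = c_out
    total += c_in * num_classes + num_classes  # classifier weight + bias
    return total


print(resnet18_params())                           # 11689512
print(resnet18_params(widths=(32, 64, 128, 256)))  # 3055880
```

Halving both the input and output channels of a convolution divides its weights by four. The overall reduction here is closer to 3.8x, mostly because the stem's input channels and the classifier's outputs are not pruned.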
