Skip to content

Latest commit

 

History

History
284 lines (221 loc) · 11.4 KB

self_defined_filter_pruning.md

File metadata and controls

284 lines (221 loc) · 11.4 KB

自定义剪裁

1. 概述

该教程介绍如果在PaddleSlim提供的接口基础上快速自定义Filters剪裁策略。 在PaddleSlim中,所有剪裁FiltersPruner继承自基类FilterPrunerFilterPruner中自定义了一系列通用方法,用户只需要重载实现FilterPrunercal_mask接口,cal_mask接口定义如下:

def cal_mask(self, var_name, pruned_ratio, group):
    raise NotImplemented()

cal_mask接口接受的参数说明如下:

  • var_name: 要剪裁的目标变量,一般为卷积层的权重参数的名称。在Paddle中,卷积层的权重参数格式为[output_channel, input_channel, kernel_size, kernel_size],其中,output_channel为当前卷积层的输出通道数,input_channel为当前卷积层的输入通道数,kernel_size为卷积核大小。
  • pruned_ratio: 对名称为var_name的变量的剪裁率。
  • group: 与待裁目标变量相关的所有变量的信息。

1.1 Group概念介绍

图1-1 卷积层关联关系示意图

如图1-1所示,在给定模型中有两个卷积层,第一个卷积层有3个filters,第二个卷积层有2个filters。如果删除第一个卷积绿色的filter,第一个卷积的输出特征图的通道数也会减1,同时需要删掉第二个卷积层绿色的kernels。如上所述的两个卷积共同组成一个group,表示如下:

group = {
            "conv_1.weight":{
                "pruned_dims": [0],
                "layer": conv_layer_1,
                "var": var_instance_1,
                "value": var_value_1,
            },
            "conv_2.weight":{
                "pruned_dims": [1],
                "layer": conv_layer_2,
                "var": var_instance_2,
                "value": var_value_2,
            }
        }

在上述表示group的数据结构示例中,conv_1.weight为第一个卷积权重参数的名称,其对应的value也是一个dict实例,存放了当前参数的一些信息,包括:

  • pruned_dims: 类型为list<int>,表示当前参数在哪些维度上被裁。
  • layer: 类型为paddle.nn.Layer, 表示当前参数所在Layer
  • var: 类型为paddle.Tensor, 表示当前参数对应的实例。
  • value: 类型为numpy.array类型,待裁参数所存的具体数值,方便开发者使用。

图1-2为更复杂的情况,其中,Add操作的所有输入的通道数需要保持一致,Concat操作的输出通道数的调整可能会影响到所有输入的通道数,因此group中可能包含多个卷积的参数或变量,可以是:卷积权重、卷积bias、batch norm相关参数等。

图1-2 复杂网络示例

2. 定义模型

import paddle
from paddle.vision.models import mobilenet_v1
net = mobilenet_v1(pretrained=False)
paddle.summary(net, (1, 3, 32, 32))

3. L2NormFilterPruner

该小节参考L1NormFilterPruner实现L2NormFilterPruner,方式为集成FIlterPruner并重载cal_mask接口。代码如下所示:

import numpy as np
from paddleslim.dygraph import FilterPruner

class L2NormFilterPruner(FilterPruner):
    def __init__(self, model, inputs, sen_file=None, opt=None):
        super(L2NormFilterPruner, self).__init__(
            model, inputs, sen_file=sen_file, opt=opt)

    def cal_mask(self, pruned_ratio, collection):
        var_name = collection.master_name
        pruned_axis = collection.master_axis
        value = collection.values[var_name]
        groups = 1
        for _detail in collection.all_pruning_details():
            assert (isinstance(_detail.axis, int))
            if _detail.axis == 1:
                _groups = _detail.op.attr('groups')
                if _groups is not None and _groups > 1:
                    groups = _groups
                    break

        reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
        scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
        if groups > 1:
            scores = scores.reshape([groups, -1])
            scores = np.mean(scores, axis=1)

        sorted_idx = scores.argsort()
        pruned_num = int(round(len(sorted_idx) * pruned_ratio))
        pruned_idx = sorted_idx[:pruned_num]

        mask_shape = [value.shape[pruned_axis]]
        mask = np.ones(mask_shape, dtype="int32")
        if groups > 1:
            mask = mask.reshape([groups, -1])
        mask[pruned_idx] = 0
        return mask.reshape(mask_shape)

如上述代码所示,我们重载了FilterPruner基类的cal_mask方法,并在L1NormFilterPruner代码基础上,修改了计算通道重要性的语句,将其修改为了计算L2Norm的逻辑:

scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))

接下来定义一个L2NormFilterPruner对象,并调用prune_var方法对单个卷积层进行剪裁,prune_var方法继承自FilterPruner,开发者不用再重载实现。 按以下代码调用prune_var方法后,参数名称为conv2d_0.w_0的卷积层会被裁掉50%的filters,与之相关关联的后续卷积和BatchNorm相关的参数也会被剪裁。prune_var不仅会对待裁模型进行inplace的裁剪,还会返回保存裁剪详细信息的PruningPlan对象,用户可以直接打印PruningPlan对象内容。 最后,可以通过调用Prunerrestore方法,将已被裁剪的模型恢复到初始状态。

pruner = L2NormFilterPruner(net, [1, 3, 32, 32])
plan = pruner.prune_var("conv2d_0.w_0", 0, 0.5)
print(plan)
pruner.restore()

4. FPGMFilterPruner

参考:Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration

4.1 原理介绍

如图4-1所示,传统基于Norm统计方法的filter重要性评估方式的有效性取决于卷积层权重数值的分布,比较理想的分布式要满足两个条件:

  • 偏差(deviation)要大
  • 最小值要小(图4-1中v1)

满足上述条件后,我们才能裁掉更多Norm统计值较小的参数,如图4-1中红色部分所示。

图 4-1

而现实中的模型的权重分布如图4-2中绿色分布所示,总是有较小的偏差或较大的最小值。

图 4-2

考虑到上述传统方法的缺点,FPGM则用filter之间的几何距离来表示重要性,其遵循的原则就是:几何距离比较近的filters,作用也相近。 如图4-3所示,有3个filters,将各个filter展开为向量,并两两计算几何距离。其中,绿色filter的重要性得分就是它到其它两个filter的距离和,即0.7071+0.5831=1.2902。同理算出另外两个filters的得分,绿色filter得分最高,其重要性最高。

图 4-3

4.2 实现

以下代码通过继承FilterPruner并重载cal_mask实现了FPGMFilterPruner,其中,get_distance_sum用于计算第out_idx个filter的重要性。

import numpy as np
from paddleslim.dygraph import FilterPruner

class FPGMFilterPruner(FilterPruner):
    def __init__(self, model, inputs, sen_file=None, opt=None):
        super(FPGMFilterPruner, self).__init__(
            model, inputs, sen_file=sen_file, opt=opt)

    def cal_mask(self, pruned_ratio, collection):
        var_name = collection.master_name
        pruned_axis = collection.master_axis
        value = collection.values[var_name]
        groups = 1
        for _detail in collection.all_pruning_details():
            assert (isinstance(_detail.axis, int))
            if _detail.axis == 1:
                _groups = _detail.op.attr('groups')
                if _groups is not None and _groups > 1:
                    groups = _groups
                    break

        dist_sum_list = []
        for out_i in range(value.shape[0]):
            dist_sum = self.get_distance_sum(value, out_i)
            dist_sum_list.append(dist_sum)
        scores = np.array(dist_sum_list)

        if groups > 1:
            scores = scores.reshape([groups, -1])
            scores = np.mean(scores, axis=1)

        sorted_idx = scores.argsort()
        pruned_num = int(round(len(sorted_idx) * pruned_ratio))
        pruned_idx = sorted_idx[:pruned_num]
        mask_shape = [value.shape[pruned_axis]]
        mask = np.ones(mask_shape, dtype="int32")
        if groups > 1:
            mask = mask.reshape([groups, -1])
        mask[pruned_idx] = 0
        return mask.reshape(mask_shape)

    def get_distance_sum(self, value, out_idx):
        w = value.view()
        w.shape = value.shape[0], np.product(value.shape[1:])
        selected_filter = np.tile(w[out_idx], (w.shape[0], 1))
        x = w - selected_filter
        x = np.sqrt(np.sum(x * x, -1))
        return x.sum()

接下来声明一个FPGMFilterPruner对象进行验证:

pruner = FPGMFilterPruner(net, [1, 3, 32, 32])
plan = pruner.prune_var("conv2d_0.w_0", 0, 0.5)
print(plan)
pruner.restore()

5. 敏感度剪裁

在第3节和第4节,开发者自定义实现的L2NormFilterPrunerFPGMFilterPruner也继承了FilterPruner的敏感度计算方法sensitive和剪裁方法sensitive_prune

5.1 预训练

import paddle.vision.transforms as T
transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
train_dataset = paddle.vision.datasets.Cifar10(mode="train", backend="cv2",transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(mode="test", backend="cv2",transform=transform)
from paddle.static import InputSpec as Input
optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1,
        parameters=net.parameters())

inputs = [Input([None, 3, 32, 32], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
net = mobilenet_v1(pretrained=False)
model = paddle.Model(net, inputs, labels)
model.prepare(
        optimizer,
        paddle.nn.CrossEntropyLoss(),
        paddle.metric.Accuracy(topk=(1, 5)))
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(result)

5.2 计算敏感度

pruner = FPGMFilterPruner(net, [1, 3, 32, 32], opt=optimizer)
def eval_fn():
        result = model.evaluate(
            val_dataset,
            batch_size=128)
        return result['acc_top1']
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./fpgm_sen.pickle")
print(sen)

5.3 剪裁

from paddleslim.analysis import dygraph_flops
flops = dygraph_flops(net, [1, 3, 32, 32])
print(f"FLOPs before pruning: {flops}")
plan = pruner.sensitive_prune(0.4, skip_vars=["conv2d_26.w_0"])
flops = dygraph_flops(net, [1, 3, 32, 32])
print(f"FLOPs after pruning: {flops}")
print(f"Pruned FLOPs: {round(plan.pruned_flops*100, 2)}%")
result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(f"before fine-tuning: {result}")

5.4 重训练

model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(f"after fine-tuning: {result}")