
[slim] Refine framework of slim and add filter pruning strategy #16226

Merged: 26 commits, Mar 23, 2019

Conversation

wanghaoshuang (Contributor):
  1. Add the framework of paddle slim
  2. Add filter pruning strategy

test=develop
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Contributor:

Wasn't it decided that pruning and distillation would no longer use graph? Is graph_wrapper still needed?

Contributor Author:

To make it easy to switch to IrGraph later.

@@ -13,9 +13,10 @@
# limitations under the License.

Contributor:

Quantization Strategy is not included?

Contributor Author:

The Quantization Strategy will be submitted in a separate PR.

@@ -17,7 +17,7 @@
import yaml
from collections import OrderedDict
from ..prune import *
from .compress_pass import *
from ..quantization import *
Contributor:

I don't see where the quantization is?


cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0.
"""
np.random.seed(cached_id)
cache_path = cache_path + "/" + str(cached_id)
Contributor:

Suggest using os.path.join(dir0, dir1, ..., file) so the "/" separator is added automatically.
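The suggestion can be sketched as follows; the `cache_path` and `cached_id` names mirror the snippet under review, and the directory value is illustrative:

```python
import os

cache_path = "/tmp/slim_cache"  # illustrative directory
cached_id = 0

# String concatenation requires inserting separators by hand:
manual = cache_path + "/" + str(cached_id)

# os.path.join adds the platform separator automatically:
joined = os.path.join(cache_path, str(cached_id))

print(joined)  # /tmp/slim_cache/0 on POSIX systems
```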

Contributor Author:

Fixed. Thx.

Load the context from file.
"""
with open(file_name) as context_file:
data = pickle.load(context_file)
Contributor:

Has pickle been tested under Python 3? Last time, when we submitted the video code, pickle.load raised an error under Python 3, and we changed it to the following form:

if python_ver < (3, 0):
    data_loaded = pickle.load(open(pickle_path, 'rb'))
else:
    data_loaded = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
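A self-contained version of that pattern might look like this; the `load_context` name follows the docstring above, and `python_ver` from the snippet is replaced by the stdlib `sys.version_info`:

```python
import pickle
import sys

def load_context(pickle_path):
    """Load a pickled object in a way that works on Python 2 and 3.

    Pickles written by Python 2 may contain byte strings that Python 3's
    default ASCII decoding rejects, hence encoding='bytes' on Python 3.
    """
    with open(pickle_path, 'rb') as f:  # binary mode is required either way
        if sys.version_info < (3, 0):
            return pickle.load(f)
        return pickle.load(f, encoding='bytes')
```

Using a `with` block also guarantees the file handle is closed, which the inline `open(...)` calls in the snippet above do not.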

Contributor Author:

Fixed. Thx.

qingqing01 (Contributor) left a comment:

Not finished yet

batch += 1
yield data

return s_reader
Contributor:

Why are these readers needed?

Contributor Author:

To speed up the data feeder, though not every case actually gets faster. This logic indeed doesn't really belong inside the compression tool.

Contributor:

Users can do the feeder speed-up themselves, outside the tool; suggest removing it.
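If a user did want this speed-up outside the compression tool, one sketch is a background-thread prefetch wrapper around a paddle-style reader (a no-argument callable returning a generator). The `prefetch_reader` and `buf_size` names are illustrative, not part of the PR:

```python
import queue
import threading

def prefetch_reader(reader, buf_size=8):
    """Wrap a paddle-style reader so samples are produced by a
    background thread and buffered in a bounded queue."""
    def wrapped():
        q = queue.Queue(maxsize=buf_size)
        sentinel = object()  # marks end of data

        def producer():
            for sample in reader():
                q.put(sample)
            q.put(sentinel)

        t = threading.Thread(target=producer)
        t.daemon = True
        t.start()
        while True:
            sample = q.get()
            if sample is sentinel:
                break
            yield sample
    return wrapped
```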

Contributor Author:

Thx. Done.

teacher_graphs: The teacher graphs used in distillation strategies.
train_optimizer: The optimizer used to append backward ops and
optimization ops into train_graph.
distiller_optimizer: The optimizer used by distillation strategies.
Contributor:

In distillation, how should train_optimizer and distiller_optimizer be set?

Contributor Author:

In the distiller framework there is no unified user interface; different tasks each have their own entry point, e.g. the classification task uses one interface and the language model uses another.

Of these, the classification interface provides distillation; the language model does not.
In the classification entry compress_classifier.py, the optimizer is hard-coded to SGD, and the distillation strategy can only use that optimizer. If users need a different optimizer, their only option is to modify compress_classifier.py, which mixes in scheduling logic for all the compression strategies and is very unfriendly to users.

For the classification task, distillation does not use an independent optimizer; its distillation steps are:

  1. Call the knowledge distillation policy's forward for the forward pass
  2. Run the backward pass with the optimizer shared by all compression strategies

In summary, distiller's user-interface design, its use of optimizers, and the way it invokes the distillation strategy are not worth using as a reference.

Contributor Author:

Most of the demos distiller provides are based on compress_classifier.py; all users can really tune are some parameters of the SGD optimizer.

Contributor Author:

Sorry, I misunderstood the question. I will add comments here explaining in detail what each of the two optimizers is for. Thanks.

Contributor:

Sorry, what I meant above is that when running the distillation compression algorithm, how these two optimizers should be set needs to be explained clearly.

Contributor Author:

Thx. Done.

logger.info('Latest evaluations: {}'.format(results))
return abs(results[1] - results[0]) / results[0] < delta

def run_eval_graph(self, sampled_rate=None, cached_id=0):
Contributor:

Why is run_eval_graph in the Context?

Contributor Author:

So that every strategy can evaluate the current graph's performance at any time. This method could be placed in:

  1. Each strategy class: duplicated implementations, troublesome for strategy authors.
  2. The CompressPass class: then the strategies cannot access it.
  3. GraphWrapper?
  4. A separately implemented ExecutorHelper?

if 'init_model' in factory.compress_pass:
self.init_model = factory.compress_pass['init_model']

def _init_model(self, context):
Contributor:

As discussed before, a Pass is only used to transform the graph; it does not include running (train/eval).

Contributor Author:

CompressPass actually isn't a pass. Shall I rename it to Compressor?

Contributor:

I agree to change the name, to keep the meaning of Pass consistent between Python and C++. What do you think? @panyx0718 @wzzju

Contributor Author:

Renamed CompressPass to Compressor.

panyx0718 previously approved these changes Mar 21, 2019
Pruner used to prune parameters by groups.
"""

def __init__(self, pruning_axis, criterions):
Contributor:

You might want to add some comments for this and many others; the code is not easy to understand.

Contributor Author:

Thx. Fixed.

class: 'SensitivePruneStrategy'
pruner: 'pruner_1'
start_epoch: 1
delta_rate: 0.2
Contributor:

You might want to document these metrics.

Contributor Author:

Thx. Fixed.

num_steps: 1
eval_rate: 0.5
pruned_params: 'conv6_sep_weights'
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
Contributor:

doc?

Contributor Author:

Thx. Fixed.

num_steps: 1
eval_rate: 0.5
pruned_params: '.*_sep_weights'
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
Contributor:

doc?

Contributor Author:

Thx. Fixed.
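For reference, the `pruned_params` field in the configs above is a regular expression over parameter names; a minimal sketch of the selection it implies, where the parameter names below are made up in the style of the MobileNet example:

```python
import re

# Hypothetical parameter names, illustrative only.
param_names = [
    'conv1_weights',
    'conv2_1_sep_weights',
    'conv2_2_sep_weights',
    'fc7_weights',
]

pruned_params = '.*_sep_weights'  # pattern from the YAML config above
selected = [n for n in param_names if re.match(pruned_params, n)]
print(selected)  # only the *_sep_weights parameters match
```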


strategies = self.strategies
if self.checkpoint_path:
if not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
Contributor:

This is load_checkpoint; the path may not even exist yet. Why makedirs here?
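The fix being requested amounts to guarding the load path rather than creating it; a minimal sketch, assuming a pickle-serialized checkpoint file (the function name and format are illustrative, not the PR's exact code):

```python
import os
import pickle

def load_checkpoint(checkpoint_path):
    """Return the saved context if a checkpoint exists, else None.

    Unlike saving, loading must not create directories: a missing
    path simply means there is nothing to restore.
    """
    if not checkpoint_path or not os.path.exists(checkpoint_path):
        return None
    with open(checkpoint_path, 'rb') as f:
        return pickle.load(f)
```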

Contributor Author:

Thx. Fixed.


for epoch in range(self.epoch):
reader = feed_reader(
Contributor:

Suggest removing feed_reader; the reader the user passes in from outside may already be multi-threaded or multi-process.

Contributor Author:

Thx. Fixed.

"""
Running evaluation.
"""
results, names = context.run_eval_graph()
Contributor:

This class is shared by all the compression algorithms, right? If so, can run_eval_graph from the Context be placed directly here?

Contributor Author:

The Compressor (CompressPass) object contains the strategies; in terms of the hierarchy, a strategy should not access the Compressor object directly.
If we used strategy.on_compression_epoch(compressor):

  • The two objects would depend on each other
  • The strategy could reach too much, e.g. it could call compressor.run() directly and create an infinite loop

The Context wraps up the information and capabilities of the Compressor that strategies are allowed to, and need to, access. The Compressor itself can also use it, as on this line.
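The layering described above can be sketched with toy stand-ins: the strategy only ever receives the Context, which exposes a narrow slice of what the Compressor owns. All classes below are simplified illustrations, not the PR's real API:

```python
class Context(object):
    """Narrow view of the Compressor that strategies may use."""
    def __init__(self, eval_fn):
        self._eval_fn = eval_fn
        self.epoch_id = 0

    def run_eval_graph(self):
        return self._eval_fn()


class Strategy(object):
    def on_epoch_end(self, context):
        # A strategy can evaluate at any time, but has no handle on the
        # Compressor itself, so it cannot call compressor.run().
        return context.run_eval_graph()


class Compressor(object):
    def __init__(self, strategies):
        self.strategies = strategies
        self.context = Context(eval_fn=self._evaluate)

    def _evaluate(self):
        return {'acc_top1': 0.7}  # stand-in for running the eval graph

    def run(self):
        results = []
        for strategy in self.strategies:
            self.context.epoch_id += 1
            results.append(strategy.on_epoch_end(self.context))
        return results
```

This breaks the mutual dependency: Strategy depends only on Context, and Compressor owns both.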

graph.program.global_block().var(name) for name in fetches
]
results = self.exe.run(graph.program,
def run(self, graph, scope, data=None, feed=None, fetches=None):
Contributor:

Add comments: what type is data?

Contributor Author:

Thx. Fixed.

graph.program.global_block().var(name) for name in fetches
]
results = self.exe.run(graph.program,
def run(self, graph, scope, data=None, feed=None, fetches=None):
Contributor:

Add comments: what types are data, feed, and fetches?

Args:
context(slim.core.Context): The context storing all information used to evaluate the current model.
sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None.
cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0.
Contributor:

I don't quite understand the explanation of cache_id.

Contributor Author:

cache_id is the unique identifier of the randomly sampled data subset.
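That explanation can be sketched as follows: cached_id both seeds the sampler and keys the cache, so evaluations with the same id see an identical subset. The function below is an illustrative stand-in, using the stdlib random module where the PR's code calls np.random.seed(cached_id):

```python
import random

_sample_cache = {}

def sampled_indices(dataset_size, sampled_rate, cached_id=0):
    """Sample a fixed fraction of indices; the same cached_id always
    yields the same subset, so repeated evaluations are comparable."""
    if cached_id not in _sample_cache:
        random.seed(cached_id)  # cached_id seeds the sampler
        n = int(dataset_size * sampled_rate)
        _sample_cache[cached_id] = random.sample(range(dataset_size), n)
    return _sample_cache[cached_id]
```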

2. Fix cache reader
3. Rename CompressPass to Compressor
4. Add comments for distiller optimizer
5. Remove unused pruner currently
6. Add some comments.
7. Change API.spec
test=develop
qingqing01 previously approved these changes Mar 22, 2019
@wanghaoshuang wanghaoshuang force-pushed the release_slim_pruning branch 2 times, most recently from 23bd353 to b7f8b4f Compare March 22, 2019 14:59
chengduoZH (Contributor) left a comment:

LG api.spec

@wanghaoshuang wanghaoshuang merged commit 2e5831f into PaddlePaddle:develop Mar 23, 2019
@wanghaoshuang wanghaoshuang deleted the release_slim_pruning branch May 20, 2022 03:55