[slim] Refine framework of slim and add filter pruning strategy #16226
Conversation
wanghaoshuang
commented
Mar 15, 2019
- Add the framework of paddle slim
- Add filter pruning strategy
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Didn't we decide that pruning and distillation no longer use graph? Do we still need graph_wrapper?

To make it easier to switch to IrGraph later.
@@ -13,9 +13,10 @@
# limitations under the License.
Quantization Strategy is not included?

We plan to submit the Quantization Strategy in a separate PR.
@@ -17,7 +17,7 @@
import yaml
from collections import OrderedDict
from ..prune import *
from .compress_pass import *
from ..quantization import *
I don't see where the quantization is.

quantization is a module that already exists on the develop branch: https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/fluid/contrib/slim/quantization
cached_id(int): The id of the sampled dataset. Evaluations with the same cached_id use the same sampled dataset. default: 0.
"""
np.random.seed(cached_id)
cache_path = cache_path + "/" + str(cached_id)
It is recommended to use os.path.join(dir0, dir1, ..., file), which inserts the "/" automatically.

Fixed. Thx.
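As a small aside (illustrative only, not code from this PR — the variable names are hypothetical), os.path.join builds the path without manual separator concatenation:

```python
import os

# os.path.join inserts the platform-specific separator automatically,
# so there is no need for cache_path + "/" + str(cached_id).
cache_path = "cache_dir"
cached_id = 0
full_path = os.path.join(cache_path, str(cached_id))
```

On POSIX systems this yields `cache_dir/0`, and on Windows the correct backslash form, which is why it is preferred over string concatenation.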
Load the context from file.
"""
with open(file_name) as context_file:
    data = pickle.load(context_file)
Has pickle been tested under Python 3? Last time we submitted the video code, pickle.load raised an error under Python 3, and we changed it to the following form:

import sys

if sys.version_info < (3, 0):
    data_loaded = pickle.load(open(pickle_path, 'rb'))
else:
    data_loaded = pickle.load(open(pickle_path, 'rb'), encoding='bytes')

Fixed. Thx.
Not finished yet
batch += 1
yield data

return s_reader
Why are these readers needed?

They were added to speed up the data feeder, but they don't bring a speedup in every case. You're right that this logic doesn't really belong inside the compression tool.

Users can do feeder acceleration outside the tool themselves; I suggest removing it.

Thx. Done.
teacher_graphs: The teacher graphs used in distillation strategies.
train_optimizer: The optimizer used to append backward ops and
    optimization ops into train_graph.
distiller_optimizer: The optimizer used by distillation strategies.
In distillation, how should train_optimizer and distiller_optimizer be set?

The distiller framework has no unified user interface; each task has its own entry point, e.g. classification uses one interface and language modeling uses another. The classification interface provides distillation; the language model interface does not.
In the classification entry point, compress_classifier.py, the optimizer is hard-coded to SGD, and the distillation strategy can only use that optimizer. If users need to switch optimizers, they have to modify compress_classifier.py itself, which mixes in the scheduling logic for all the compression strategies and is very unfriendly to users.
For the classification task, distillation does not use an independent optimizer. Its distillation steps are:
- Call the forward of the knowledge distillation policy to run the forward pass.
- Run the backward pass with the optimizer shared by all compression strategies.
In summary, distiller's user-interface design, its use of optimizers, and the way its distillation strategies are invoked are not worth using as a reference.
Most of the demos distiller provides are based on compress_classifier.py; all users can adjust are some parameters of the SGD optimizer.
Sorry, I misread the question. I will add comments here explaining in detail what each of the two optimizers is for. Thanks.
Sorry, what I meant above is that it needs to be clearly explained how these two optimizers should be set when using the distillation compression algorithm.
Thx. Done.
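To make the distinction concrete, here is a toy sketch of the dispatch the docstring describes (the function name and the `is_distillation` flag are hypothetical, not the slim API): distillation strategies use distiller_optimizer, while everything else shares train_optimizer.

```python
# Hypothetical sketch, not the actual paddle.fluid.contrib.slim API.
def pick_optimizer(strategies, train_optimizer, distiller_optimizer):
    # Distillation strategies get their own optimizer; all other
    # strategies share the regular training optimizer.
    if any(s.get("is_distillation", False) for s in strategies):
        return distiller_optimizer
    return train_optimizer

# With a distillation strategy active, the distiller optimizer wins.
opt = pick_optimizer([{"is_distillation": True}], "sgd", "adam")
```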
logger.info('Latest evaluations: {}'.format(results))
return abs(results[1] - results[0]) / results[0] < delta

def run_eval_graph(self, sampled_rate=None, cached_id=0):
Why is run_eval_graph in the Context?

So that every strategy can evaluate the current graph's performance at any time. This method could live in:
- Each strategy class: duplicated implementations, troublesome for strategy authors.
- The CompressPass class: individual strategies cannot reach it there.
- GraphWrapper?
- A standalone ExecutorHelper?
if 'init_model' in factory.compress_pass:
    self.init_model = factory.compress_pass['init_model']

def _init_model(self, context):
As discussed before, a Pass is only used to transform the graph; it should not include running (train/eval).

CompressPass is not really a pass; shall I rename it to Compressor?

I agree with changing the name, to keep the meaning of Pass consistent between Python and C++. What do you think? @panyx0718 @wzzju

Renamed CompressPass to Compressor.
Pruner used to prune parameters by groups.
"""

def __init__(self, pruning_axis, criterions):
You might want to add some comments for this and many other places; the code is not easy to understand.

Thx. Fixed.
class: 'SensitivePruneStrategy'
pruner: 'pruner_1'
start_epoch: 1
delta_rate: 0.2
You might want to document these settings.

Thx. Fixed.
num_steps: 1
eval_rate: 0.5
pruned_params: 'conv6_sep_weights'
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
doc?
Thx. Fixed.
num_steps: 1
eval_rate: 0.5
pruned_params: '.*_sep_weights'
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
doc?
Thx. Fixed.
strategies = self.strategies
if self.checkpoint_path:
    if not os.path.exists(self.checkpoint_path):
        os.makedirs(self.checkpoint_path)
This is load_checkpoint; the path may simply not exist yet. Why makedirs here?

Thx. Fixed.
for epoch in range(self.epoch):
    reader = feed_reader(
I suggest removing feed_reader; the reader the user passes in may already be multi-threaded/multi-process.

Thx. Fixed.
"""
Running evaluation.
"""
results, names = context.run_eval_graph()
This class is shared by all the compression algorithms, right? If so, can run_eval_graph from Context be put directly here?

The Compressor (CompressPass) object contains the strategies, so by the layering, a strategy should not access the Compressor object directly.
If we used strategy.on_compression_epoch(compressor):
- the two objects would depend on each other;
- the strategy could access far too much, e.g. it could call compressor.run() directly and create an infinite loop.
Context wraps up exactly the state and capabilities of the compressor that strategies are allowed and need to access. The compressor itself can also use it, as on this line.
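A minimal sketch of this mediator arrangement (class names mirror the discussion, but the bodies are illustrative stand-ins, not the real slim classes):

```python
# Illustrative sketch of the Context-as-mediator design discussed above.
class Context(object):
    """Exposes only the state and capabilities strategies may use."""
    def __init__(self, eval_fn):
        self._eval_fn = eval_fn

    def run_eval_graph(self):
        return self._eval_fn()

class Strategy(object):
    def on_epoch_end(self, context):
        # A strategy can evaluate via the context, but it holds no
        # reference to the Compressor, so it cannot call run() back
        # and create the infinite loop mentioned above.
        return context.run_eval_graph()

class Compressor(object):
    def __init__(self, strategy):
        self.strategy = strategy
        self.context = Context(eval_fn=lambda: 0.91)  # stub metric

    def run(self):
        # The compressor itself may also use the context.
        return self.strategy.on_epoch_end(self.context)

metric = Compressor(Strategy()).run()
```

The one-way dependency (Strategy sees Context, never Compressor) is the point of the design: it caps what a strategy can do to exactly what the context exposes.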
graph.program.global_block().var(name) for name in fetches
]
results = self.exe.run(graph.program,
def run(self, graph, scope, data=None, feed=None, fetches=None):
Add comments: what types are data, feed, and fetches?

Thx. Fixed.
Args:
    context(slim.core.Context): The context storing all information used to evaluate the current model.
    sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None.
    cached_id(int): The id of the sampled dataset. Evaluations with the same cached_id use the same sampled dataset. default: 0.
The explanation of cached_id is not quite clear to me.

cached_id is the unique identifier of the randomly sampled data subset.
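A toy illustration of why seeding with cached_id makes the subset reproducible (this is not the slim code; the sampling function is hypothetical):

```python
import numpy as np

def sample_indices(cached_id, n_total=100, n_sample=10):
    # Seeding the RNG with cached_id fixes the random draw, so two
    # evaluations that pass the same cached_id score the model on
    # exactly the same sampled subset of the data.
    np.random.seed(cached_id)
    return np.random.choice(n_total, n_sample, replace=False)

same = (sample_indices(0) == sample_indices(0)).all()
```

Here `same` is True: repeating the call with cached_id=0 reproduces the identical index set, while a different cached_id selects a fresh subset.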
LG api.spec