[MKL-DNN] Fully Connected #15226
paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc (new file)
@@ -0,0 +1,77 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
(Review thread on this line)

> Since we have … [comment truncated]

It's because the weights of the layer need to be transposed in the fc mkldnn pass. This allows MKL-DNN's algorithm to execute much more efficiently. I can explain further why that is, if necessary.

> OK, this makes sense, but could we reuse some code? It seems only the weights need reordering. Another question: is this for inference only? If it ran during training, would it cause gradient-update issues, or weight-printing issues, since the format may not be NCHW?

What code should be reused in this case? The pass checks whether the Input has the correct dimensions and applies the transpose only in that case. This op is designed to work for inference only.

fc_mkldnn_pass now turns MKL-DNN's fully connected layer on if the input has 2 or 4 dimensions.
  PADDLE_ENFORCE(graph);
  Init("fc_mkldnn_pass", graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()
                ->NewNode("fc_mkldnn_pass/x")
                ->AsInput()
                ->assert_is_op_input("fc", "Input");
  patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
  fc_pattern(x, true /*with bias*/);

  int found_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Handle FC MKL-DNN pass";
    if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
      VLOG(3) << "do not perform fc fuse";
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);

    OpDesc* desc = fc->Op();
    auto in_size = fc->inputs[0]->Var()->GetShape().size();
    if (in_size != 2 && in_size != 4) {
      VLOG(3) << "Do not enable FC MKL-DNN for dimensions different than 2 & 4";
      return;
    }
    desc->SetAttr("use_mkldnn", true);
    PADDLE_ENFORCE(subgraph.count(x));

    found_fc_count++;
  };

  gpd(graph, handler);

  AddStatis(found_fc_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fc_mkldnn_pass, paddle::framework::ir::FCMKLDNNPass);
paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h (new file)
@@ -0,0 +1,38 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Transpose weights of FC to comply with MKL-DNN interface
 */
class FCMKLDNNPass : public FusePassBase {
 public:
  virtual ~FCMKLDNNPass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
(Review conversation)

> Since we have PDNode *patterns::FC::operator(), why do we need PDNode *patterns::FCMKLDNN::operator() again? We don't have PDNode *patterns::xxxMKLDNN::operator() in this file.

It's because the FC pattern detector searches for a mul + elementwise_add pattern, while FCMKLDNN searches for an FC op pattern.

> I still don't get this point. From the pattern side, is there any difference from the original FC pattern? Maybe we only need to add use_mkldnn=True?

> The reason you transpose the weights of this op in the pass is to avoid a duplicated transpose in the op's compute(). How much time does this weight transpose take? From the framework perspective, however, xxx_mkldnn_op should have the same behavior as xxx_op: they may have different kernels, but with the same inputs, outputs, and weights; i.e., fc's weights should be transposed inside the mkldnn kernel. As it stands, "it is not enough to set use_mkldnn to true" may confuse users. @jianhang-liu, what do you think about it?

There is no need to do that: if we save the optimized model with all the passes applied, then once it is loaded again it will execute just fine in an MKL-DNN environment, because the passes will already have been applied and the weights transposed. There is no point in running a saved optimized model in another environment anyway, because passes such as conv + batch_norm + bias already introduce changes that are only applicable in an MKL-DNN-only environment (there is no support for bias in the reference conv). Is that what you meant?

> Got it for FP32, but for INT8 we will run the saved optimized model in another environment. #17097 is based on this PR and uses the weight transpose as well.

As far as I understand, if the kernel from #17097 is adapted to use transposed weights, then everything should be set and ready for running the saved optimized model in the INT8 Python environment. If not, it is always possible to transpose the weights back using the transpiler.

The current implementation doesn't modify the weights. It transposes them and stores the transposed copy internally in the execution context.