Enable the detection of subgraph composed of grad ops #21223

Merged: 23 commits, Feb 7, 2020
Commits (23)
952615a
Add the first implementation of fusion_group op #19621 (#3)
Xreki Nov 18, 2019
1de02e5
Enable generating code for a given subgraph. #21126 (#4)
Xreki Nov 18, 2019
7061d85
Enable the detection of subgraph of grad ops.
Xreki Nov 18, 2019
7ec4836
Merge branch 'develop' into fusion_group_final
Xreki Nov 20, 2019
6d1159e
Generate code for detected subgraph in fusion_group_pass.
Xreki Nov 20, 2019
e4f20ca
Add an option in BuildStrategy to enable fusion_group_pass and add un…
Xreki Nov 20, 2019
310b5b2
Merge branch 'develop' into fusion_group_final
Xreki Nov 22, 2019
7aebc6a
Fix a bug when checking whether the shapes of all inputs are the same.
Xreki Nov 22, 2019
241c332
Merge branch 'develop' into fusion_group_final
Xreki Dec 4, 2019
582d0f6
Merge branch 'develop' into fusion_group_final
Xreki Dec 9, 2019
c59530a
Add debug information.
Xreki Dec 23, 2019
32ac208
Merge branch 'develop' into fusion_group_final
Xreki Jan 3, 2020
d6de64e
Merge branch 'develop' into fusion_group_final
Xreki Jan 6, 2020
ce73cd3
Remove subgraph_detector from inference/analysis to the common framew…
Xreki Jan 6, 2020
fc1ab12
Call subgraph_detector in fusion_group pass.
Xreki Jan 6, 2020
0127093
Disable fusion_group when WITH_GPU is OFF.
Xreki Jan 7, 2020
dd40dcb
Refine all PADDLE_ENFORCE message.
Xreki Jan 7, 2020
7fec60f
Fix the case that some inputs are not defined in grad ops, and set op…
Xreki Jan 7, 2020
1a5ebe9
Merge branch 'develop' into fusion_group_final
Xreki Jan 9, 2020
97b690b
Merge branch 'develop' into fusion_group_final
Xreki Jan 13, 2020
0e96b09
Merge branch 'develop' into fusion_group_final
Xreki Jan 13, 2020
11368b3
Merge branch 'develop' into fusion_group_final
Xreki Feb 6, 2020
160a3e5
Follow review comments.
Xreki Feb 6, 2020
38 changes: 22 additions & 16 deletions paddle/fluid/framework/details/CMakeLists.txt
@@ -64,7 +64,14 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d

cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)

set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
sequential_execution_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
reference_count_pass
eager_deletion_pass
buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})

cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
@@ -91,23 +98,22 @@ cc_library(fast_threaded_ssa_graph_executo
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)

set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass
multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass)
if(WITH_GPU)
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif()
if(WITH_NGRAPH)
set(NGRAPH_BS_DEPS ngraph)
else()
set(NGRAPH_BS_DEPS)
set(IR_PASS_DEPS ${IR_PASS_DEPS} ngraph)
endif()

cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass fuse_bn_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass
pass_builder
${NGRAPH_BS_DEPS})
cc_library(build_strategy SRCS build_strategy.cc DEPS pass_builder ${IR_PASS_DEPS})

if (WITH_MKLDNN)
target_link_libraries(build_strategy mkldnn_placement_pass)
14 changes: 13 additions & 1 deletion paddle/fluid/framework/details/build_strategy.cc
@@ -165,9 +165,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendOpFusePasses() {
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
#endif
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
// for single card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
@@ -370,6 +373,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "fusion_group_pass") {
pass->Set<bool>("use_gpu", new bool(use_cuda));
if (!use_cuda) {
LOG(WARNING) << "fusion_group_pass is only supported on GPU, skipped.";
continue;
}
} else if (pass->Type() == "fuse_bn_act_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_bn_act_pass is only supported on "
@@ -427,3 +436,6 @@ USE_PASS(mkldnn_placement_pass);
#ifdef PADDLE_WITH_NGRAPH
USE_PASS(ngraph_subgraph_pass);
#endif
#ifdef PADDLE_WITH_CUDA
USE_PASS(fusion_group_pass);
#endif
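The changes above wire the new pass in two places: AppendOpFusePasses appends fusion_group_pass only when strategy_.enable_auto_fusion_ is set (and only in CUDA builds), and Apply() sets a use_gpu attribute on the pass and skips it with a warning when CUDA is unavailable. Below is a minimal standalone sketch of that flag-gated pattern; the Strategy and PassList types are hypothetical stand-ins, not Paddle's BuildStrategy or PassBuilder.

// Minimal sketch of the "append pass behind a flag, skip on CPU" pattern
// used by BuildStrategy above. All types here are hypothetical stand-ins,
// not Paddle's actual API.
#include <iostream>
#include <string>
#include <vector>

struct Strategy {
  bool fuse_bn_act_ops{false};
  bool enable_auto_fusion{false};  // corresponds to enable_auto_fusion_ above
};

class PassList {
 public:
  // Append the pass only when its strategy flag is enabled.
  void AppendPassWithCheck(bool enabled, const std::string& name) {
    if (enabled) passes_.push_back(name);
  }

  // Mirror of BuildStrategy::Apply: GPU-only passes are skipped with a
  // warning when CUDA is not available.
  void Apply(bool use_cuda) const {
    for (const auto& pass : passes_) {
      if (pass == "fusion_group_pass" && !use_cuda) {
        std::cerr << "fusion_group_pass is only supported on GPU, skipped.\n";
        continue;
      }
      std::cout << "running " << pass << "\n";
    }
  }

 private:
  std::vector<std::string> passes_;
};

int main() {
  Strategy strategy;
  strategy.enable_auto_fusion = true;

  PassList passes;
  passes.AppendPassWithCheck(strategy.fuse_bn_act_ops, "fuse_bn_act_pass");
  passes.AppendPassWithCheck(strategy.enable_auto_fusion, "fusion_group_pass");

  passes.Apply(/*use_cuda=*/false);  // prints the warning and skips the pass
  return 0;
}

In the real code the guard also exists at compile time: the pass is registered via USE_PASS(fusion_group_pass) only under PADDLE_WITH_CUDA, so CPU-only builds never register it at all.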
3 changes: 2 additions & 1 deletion paddle/fluid/framework/details/build_strategy.h
@@ -86,8 +86,9 @@ struct BuildStrategy {
// Operator fusion
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
bool fuse_elewise_add_act_ops_{false};
bool fuse_bn_act_ops_{false};
bool fuse_elewise_add_act_ops_{false};
bool enable_auto_fusion_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types
boost::optional<bool> fuse_all_optimizer_ops_{false};
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -6,7 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
if(NOT APPLE AND NOT WIN32)
if(NOT APPLE AND NOT WIN32 AND WITH_GPU)
add_subdirectory(fusion_group)
endif()

6 changes: 4 additions & 2 deletions paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
@@ -1,9 +1,11 @@
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
cc_library(code_generator
SRCS operation.cc code_generator.cc code_generator_helper.cc
DEPS graph subgraph_detector)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif()

cc_library(fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS graph_pattern_detector pass code_generator)
DEPS subgraph_detector fuse_pass_base code_generator device_code)
cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/fusion_group/code_generator.cc
@@ -33,7 +33,7 @@ CodeGenerator::CodeGenerator() {

std::string CodeGenerator::Generate(SubGraph* subgraph) {
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
return Generate(subgraph->func_name, expressions);
return Generate(subgraph->GetFuncName(), expressions);
}

static bool HasInput(Node* n, std::string name) {
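The call-site change above swaps the public field subgraph->func_name for an accessor, subgraph->GetFuncName(). A minimal sketch of that kind of encapsulation, assuming a plain string member; this is illustrative only, not the actual fusion_group::SubGraph class.

// Hypothetical sketch of hiding a public field behind an accessor, as the
// call-site change above suggests. Not the actual fusion_group::SubGraph.
#include <string>
#include <utility>

class SubGraph {
 public:
  explicit SubGraph(std::string func_name) : func_name_(std::move(func_name)) {}

  // Callers now go through the accessor, so the class is free to validate
  // or lazily generate the name later without touching call sites.
  const std::string& GetFuncName() const { return func_name_; }

 private:
  std::string func_name_;
};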
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
@@ -227,7 +227,7 @@ std::vector<fusion_group::OperationExpression> TestMain(
std::string code_str = code_generator.Generate(subgraph);
VLOG(3) << code_str;

TestMainImpl(subgraph->func_name, code_str, cpu_tensors, n, input_ids,
TestMainImpl(subgraph->GetFuncName(), code_str, cpu_tensors, n, input_ids,
output_ids);

// Need to check the accuracy according to expressions.
144 changes: 49 additions & 95 deletions paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"

namespace paddle {
namespace framework {
@@ -26,20 +29,22 @@ static std::unordered_set<std::string> unary_op_types;

static std::unordered_set<std::string>& GetBinaryOpTypes() {
if (binary_op_types.empty()) {
binary_op_types = OperationMap::Instance().Find(0, 2);
binary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
}
return binary_op_types;
}

static std::unordered_set<std::string>& GetUnaryOpTypes() {
if (unary_op_types.empty()) {
unary_op_types = OperationMap::Instance().Find(0, 1);
unary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
}
return unary_op_types;
}

static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
Node* n) {
const Node* n) {
if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) {
auto iter = op_types.find(n->Op()->Type());
if (iter != op_types.end()) {
@@ -49,114 +54,63 @@ static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
return false;
}

static bool IsBinaryOp(Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n) && n->inputs.size() == 2U) {
auto* x = n->inputs[0];
auto* y = n->inputs[1];
static bool IsGradOp(const Node* n) {
PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
platform::errors::InvalidArgument(
"Expected node %p to be an operator node.", n));
std::string suffix = "_grad";
std::string op_type = n->Op()->Type();
size_t pos = op_type.rfind(suffix);
return pos != std::string::npos &&
pos == (op_type.length() - suffix.length());
}

std::vector<int64_t> x_shape;
std::vector<int64_t> y_shape;
if (x && x->IsVar() && x->Var()) {
x_shape = x->Var()->GetShape();
}
if (y && y->IsVar() && y->Var()) {
y_shape = y->Var()->GetShape();
}
if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) {
static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
const std::vector<int64_t>& r) {
return l.size() != 0U && r.size() != 0U && l == r;
}

static bool IsBinaryOp(const Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
return false;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i]) {

// The shape of all inputs should be the same.
std::vector<int64_t> shape_0;
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto* in_i = n->inputs[i];
if (!(in_i && in_i->IsVar() && in_i->Var())) {
return false;
}
}
return true;
}
return false;
}

static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(GetUnaryOpTypes(), n); }

bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}

bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n,
std::string name) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->outputs) {
if (IsElementwiseOp(op)) {
if (name.empty()) {
return true;
} else if (IsNthInput(n, op, name, 0)) {
return true;
std::vector<int64_t> shape_i = in_i->Var()->GetShape();
if (i == 0U) {
shape_0 = shape_i;
} else {
if (!IsEqualAndNotEmpty(shape_0, shape_i)) {
return false;
}
}
}
return true;
}
return false;
}

bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op)) {
return true;
}
}
}
return false;
static bool IsUnaryOp(const Node* n) {
return IsSpecifiedOp(GetUnaryOpTypes(), n);
}

int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) {
except_nodes_set.insert(except_nodes[i]);
}

int num_operations = 0;
if (IsElementwiseOp(n)) {
subgraph_.Insert(n);
num_operations += 1;
for (auto* var : n->inputs) {
subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
for (auto* var : n->outputs) {
subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
} else if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
for (auto* op : n->outputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
}
return num_operations;
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}

int ElementwiseGroupDetector::operator()(Node* n) {
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
name_ = n->Name();
subgraph_.Insert(n);
num_operations_ = Search(n, n->inputs);
VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
<< num_operations_ << " operations, " << GetSubgraph().GetNumNodes()
<< " nodes";
}
return num_operations_;
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) {
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };

return SubgraphDetector(graph, teller)();
}

} // namespace fusion_group
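Two of the helpers introduced above are easy to check in isolation: IsGradOp is a suffix match on the op type, and IsEqualAndNotEmpty requires both shapes to be known and identical. The standalone sketch below reproduces the same logic with plain std::string and std::vector in place of Paddle's ir::Node, so it compiles without the framework; the op names in main are made up for illustration.

// Standalone sketch of the suffix and shape checks added above.
// Uses plain std::string / std::vector instead of Paddle's ir::Node.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Same idea as IsGradOp: the op type ends with "_grad".
bool HasGradSuffix(const std::string& op_type) {
  const std::string suffix = "_grad";
  size_t pos = op_type.rfind(suffix);
  return pos != std::string::npos && pos == op_type.length() - suffix.length();
}

// Same idea as IsEqualAndNotEmpty: both shapes known and identical.
bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
                        const std::vector<int64_t>& r) {
  return !l.empty() && !r.empty() && l == r;
}

// Mirror of the loop in IsBinaryOp: every input shape must match the first.
bool AllShapesEqual(const std::vector<std::vector<int64_t>>& shapes) {
  if (shapes.empty()) return false;
  for (size_t i = 1; i < shapes.size(); ++i) {
    if (!IsEqualAndNotEmpty(shapes[0], shapes[i])) return false;
  }
  return true;
}

int main() {
  assert(HasGradSuffix("elementwise_add_grad"));
  assert(!HasGradSuffix("elementwise_add"));
  assert(!HasGradSuffix("my_grad_op"));  // "_grad" present but not a suffix

  assert(AllShapesEqual({{32, 128}, {32, 128}, {32, 128}}));
  assert(!AllShapesEqual({{32, 128}, {}}));  // unknown shape is rejected
  return 0;
}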
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h
@@ -14,10 +14,8 @@ limitations under the License. */

#pragma once

#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"

namespace paddle {
@@ -27,21 +25,10 @@ namespace fusion_group {

class ElementwiseGroupDetector {
public:
int operator()(Node* n);

SubGraph GetSubgraph() const { return subgraph_; }

private:
bool IsElementwiseOp(Node* n);
bool IsInputOfElementwiseOp(Node* n, std::string name = "");
bool IsOutputOfElementwiseOp(Node* n);

int Search(Node* n, std::vector<Node*> except_nodes = {});
std::vector<std::vector<Node*>> operator()(Graph* graph);

private:
std::string name_;
int num_operations_{0};
SubGraph subgraph_;
bool IsElementwiseOp(const Node* n);
};

} // namespace fusion_group
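With the new interface above, operator() no longer grows a single subgraph outward from a seed node; it hands a teller predicate to SubgraphDetector and gets back groups of connected candidate nodes. The sketch below illustrates that teller-and-cluster idea on a toy adjacency-list graph. It is not Paddle's SubgraphDetector, which works on ir::Graph and respects producer/consumer structure; it is only a minimal model of the pattern.

// Toy illustration of the teller-based grouping that the new
// ElementwiseGroupDetector::operator() delegates to SubgraphDetector.
// The graph below is a plain adjacency list, not Paddle's ir::Graph.
#include <functional>
#include <iostream>
#include <queue>
#include <string>
#include <vector>

struct Node {
  std::string op_type;
  std::vector<int> neighbors;  // indices into the node vector
};

// Collect connected components made only of nodes accepted by the teller.
std::vector<std::vector<int>> DetectGroups(
    const std::vector<Node>& nodes,
    const std::function<bool(const Node&)>& teller) {
  std::vector<bool> visited(nodes.size(), false);
  std::vector<std::vector<int>> groups;
  for (int start = 0; start < static_cast<int>(nodes.size()); ++start) {
    if (visited[start] || !teller(nodes[start])) continue;
    std::vector<int> group;
    std::queue<int> pending;
    pending.push(start);
    visited[start] = true;
    while (!pending.empty()) {
      int cur = pending.front();
      pending.pop();
      group.push_back(cur);
      for (int next : nodes[cur].neighbors) {
        if (!visited[next] && teller(nodes[next])) {
          visited[next] = true;
          pending.push(next);
        }
      }
    }
    groups.push_back(group);
  }
  return groups;
}

int main() {
  // Nodes 0-2 are fusable elementwise ops, node 3 is not, node 4 is fusable
  // but separated from the first group by node 3.
  std::vector<Node> nodes = {
      {"elementwise_add", {1}}, {"relu", {0, 2}}, {"elementwise_mul", {1, 3}},
      {"matmul", {2, 4}},       {"sigmoid", {3}}};
  auto teller = [](const Node& n) { return n.op_type != "matmul"; };

  for (const auto& group : DetectGroups(nodes, teller)) {
    std::cout << "group:";
    for (int idx : group) std::cout << " " << nodes[idx].op_type;
    std::cout << "\n";
  }
  // Expected: one group {elementwise_add, relu, elementwise_mul} and one
  // group {sigmoid}; the rejected matmul node splits them.
  return 0;
}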