Skip to content

Commit

Permalink
Call subgraph_detector in fusion_group pass.
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
Xreki committed Jan 6, 2020
1 parent ce73cd3 commit fc1ab12
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 364 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
cc_library(code_generator
SRCS operation.cc code_generator.cc code_generator_helper.cc
DEPS graph subgraph_detector)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif()

cc_library(fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS graph_pattern_detector pass code_generator device_code)
DEPS subgraph_detector fuse_pass_base code_generator device_code)
cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)
151 changes: 20 additions & 131 deletions paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"

namespace paddle {
namespace framework {
Expand All @@ -27,20 +28,22 @@ static std::unordered_set<std::string> unary_op_types;

static std::unordered_set<std::string>& GetBinaryOpTypes() {
if (binary_op_types.empty()) {
binary_op_types = OperationMap::Instance().Find(0, 2);
binary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
}
return binary_op_types;
}

static std::unordered_set<std::string>& GetUnaryOpTypes() {
if (unary_op_types.empty()) {
unary_op_types = OperationMap::Instance().Find(0, 1);
unary_op_types =
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
}
return unary_op_types;
}

static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
Node* n) {
const Node* n) {
if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) {
auto iter = op_types.find(n->Op()->Type());
if (iter != op_types.end()) {
Expand All @@ -50,7 +53,7 @@ static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
return false;
}

static bool IsGradOp(Node* n) {
static bool IsGradOp(const Node* n) {
PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
"Node %p should be a op node.", n);
std::string suffix = "_grad";
Expand All @@ -73,9 +76,9 @@ static bool IsEqual(const std::vector<int64_t>& l,
return true;
}

static bool IsBinaryOp(Node* n, bool backward) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n) && (IsGradOp(n) == backward)) {
if ((!backward && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
static bool IsBinaryOp(const Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
return false;
}

Expand All @@ -101,133 +104,19 @@ static bool IsBinaryOp(Node* n, bool backward) {
return false;
}

static bool IsUnaryOp(Node* n, bool backward) {
return IsSpecifiedOp(GetUnaryOpTypes(), n) && (IsGradOp(n) == backward);
static bool IsUnaryOp(const Node* n) {
return IsSpecifiedOp(GetUnaryOpTypes(), n);
}

void ElementwiseGroupDetector::Init(Graph* graph, bool backward) {
graph_ = graph;
backward_ = backward;
for (auto* n : graph_->Nodes()) {
if (IsBinaryOp(n, backward) || IsUnaryOp(n, backward)) {
elementwise_ops_.insert(n);
}
}
LOG(INFO) << "elementise ops for graph:" << graph
<< ", backward=" << backward;
LOG(INFO) << "{\n" << DebugString(elementwise_ops_) << "}\n";
// Returns true if node `n` is an op this detector can fuse: a supported
// elementwise binary op or a supported elementwise unary op.
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
  if (IsBinaryOp(n)) {
    return true;
  }
  return IsUnaryOp(n);
}

bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
if (n && n->IsOp() && n->Op()) {
return elementwise_ops_.find(n) != elementwise_ops_.end();
} else {
return false;
}
}
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) {
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };

bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n,
std::string name) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->outputs) {
if (IsElementwiseOp(op)) {
if (name.empty()) {
return true;
} else {
auto var_name = op->Op()->Input(name);
if (var_name.size() == 1U && var_name[0] == n->Name()) {
return true;
}
}
}
}
}
return false;
}

bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op)) {
return true;
}
}
}
return false;
}

int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes,
SubGraph* subgraph) {
std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) {
except_nodes_set.insert(except_nodes[i]);
}

auto search_op_handler = [&](Node* n, Node* var) -> int {
// n, is a op node.
// var, is n's input or output var node.
int num_operations = 0;
if (var && var->IsVar() && var->Var() && !subgraph->Has(var)) {
subgraph->Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations = Search(var, {n}, subgraph);
}
}
return num_operations;
};

auto search_var_handler = [&](Node* n, Node* op) -> int {
// n, is a var node.
// op, is n's input or output op node.
int num_operations = 0;
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end() &&
!subgraph->Has(op)) {
num_operations = Search(op, {n}, subgraph);
}
return num_operations;
};

int num_operations = 0;
if (IsElementwiseOp(n)) {
// LOG(INFO) << "Search[begin]:" << n->Op()->Type();
subgraph->Insert(n);
num_operations += 1;
for (auto* var : n->inputs) {
num_operations += search_op_handler(n, var);
}
for (auto* var : n->outputs) {
num_operations += search_op_handler(n, var);
}
} else if (n && n->IsVar() && n->Var()) {
// LOG(INFO) << "Search[begin]:" << n->Name();
for (auto* op : n->inputs) {
num_operations += search_var_handler(n, op);
}
for (auto* op : n->outputs) {
num_operations += search_var_handler(n, op);
}
}
return num_operations;
}

SubGraph ElementwiseGroupDetector::operator()(Node* n) {
SubGraph subgraph(0);
if (n && n->IsVar() && n->Var() && n->Name() == "split_1.tmp_0") {
LOG(INFO) << DebugString(n);
for (auto* out : n->outputs) {
LOG(INFO) << DebugString(out);
}
}
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
LOG(INFO) << "Begin with node:" << n->Name() << ", backward:" << backward_;
subgraph.Insert(n);
int num_operations = Search(n, n->inputs, &subgraph);
VLOG(3) << "Detect elementwise subgraph begin with " << n->Name() << ", "
<< num_operations << " operations, " << subgraph.GetNumNodes()
<< " nodes";
}
return subgraph;
return SubgraphDetector(graph, teller)();
}

} // namespace fusion_group
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ limitations under the License. */

#pragma once

#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"

Expand All @@ -28,25 +25,10 @@ namespace fusion_group {

class ElementwiseGroupDetector {
public:
explicit ElementwiseGroupDetector(Graph* graph, bool backward) {
Init(graph, backward);
}

SubGraph operator()(Node* n);

private:
void Init(Graph* graph, bool backward);

bool IsElementwiseOp(Node* n);
bool IsInputOfElementwiseOp(Node* n, std::string name = "");
bool IsOutputOfElementwiseOp(Node* n);

int Search(Node* n, std::vector<Node*> except_nodes, SubGraph* subgraph);
std::vector<std::vector<Node*>> operator()(Graph* graph);

private:
Graph* graph_{nullptr}; // Not owned
bool backward_{false};
std::unordered_set<Node*> elementwise_ops_;
bool IsElementwiseOp(const Node* n);
};

} // namespace fusion_group
Expand Down
69 changes: 21 additions & 48 deletions paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/platform/device_code.h"

namespace paddle {
namespace framework {
namespace ir {

void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
FusePassBase::Init("fusion_group_pass", graph);
if (Get<bool>("use_gpu")) {
fusion_group::OperationMap::Init();
int num_elementwise_groups = DetectFusionGroup(graph, 0);
Expand All @@ -37,58 +38,30 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
}

int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
std::vector<fusion_group::SubGraph> subgraphs;
std::vector<Node*> begin_of_forward_subgraph;

// Detect subgraph of forward ops.
fusion_group::ElementwiseGroupDetector forward_detector(graph, false);
std::unordered_set<Node*> all_nodes = graph->Nodes();
// TODO(liuyiqun): supported different places
platform::CUDAPlace place = platform::CUDAPlace(0);
int index = platform::DeviceCodePool::Init({place}).size(place);
for (Node* n : all_nodes) {
bool is_found = false;
for (auto& subgraph : subgraphs) {
if (subgraph.Has(n)) {
is_found = true;
break;
}
}
if (is_found) {
continue;
}

if (type == 0) {
fusion_group::SubGraph subgraph = forward_detector(n);
if (subgraph.GetNumOperations() >= 2) {
std::string func_name = "fused_elementwise_" + std::to_string(index++);
subgraph.SetFuncName(func_name);
subgraphs.push_back(subgraph);
LOG(INFO) << "subgraph: {\n" << DebugString(subgraph.Nodes()) << "}\n";
begin_of_forward_subgraph.push_back(n);
}
std::vector<std::vector<Node*>> subgraphs =
fusion_group::ElementwiseGroupDetector()(graph);

int num_subgraphs = 0;
size_t min_subgraph_size = 2;
bool save_intermediate_out = true;
for (auto& vec : subgraphs) {
if (vec.size() >= min_subgraph_size) {
std::string func_name = "fused_elementwise_" + std::to_string(index++);
fusion_group::SubGraph subgraph(type, func_name, save_intermediate_out,
vec);
LOG(INFO) << "subgraph: {\n"
<< DebugString(subgraph.SortedNodes()) << "}\n";

GenerateCode(&subgraph);
InsertFusionGroupOp(graph, &subgraph);
num_subgraphs++;
}
}
// Detect subgraph of backward ops.
fusion_group::ElementwiseGroupDetector backward_detector(graph, true);
for (auto* begin : begin_of_forward_subgraph) {
if (type == 0) {
fusion_group::SubGraph subgraph = backward_detector(begin);
if (subgraph.GetNumOperations() >= 2) {
std::string func_name =
"fused_elementwise_grad_" + std::to_string(index++);
subgraph.SetFuncName(func_name);
subgraphs.push_back(subgraph);
}
}
}

// TODO(liuyiqun): check whether there are intersection between subgraphs
for (size_t i = 0; i < subgraphs.size(); ++i) {
// GenerateCode(&subgraphs[i]);
// InsertFusionGroupOp(graph, &subgraphs[i]);
}
return subgraphs.size();
return num_subgraphs;
}

void FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
Expand Down Expand Up @@ -133,7 +106,7 @@ void FusionGroupPass::InsertFusionGroupOp(
op_desc.SetAttr("type", subgraph->GetType());
op_desc.SetAttr("func_name", subgraph->GetFuncName());

auto fusion_group_node = graph->CreateOpNode(&op_desc);
Node* fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) {
IR_NODE_LINK_TO(in, fusion_group_node);
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ limitations under the License. */

#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class FusionGroupPass : public Pass {
class FusionGroupPass : public FusePassBase {
protected:
void ApplyImpl(Graph* graph) const override;

Expand Down
Loading

0 comments on commit fc1ab12

Please sign in to comment.