-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GeneratePass for Python Pass (#35708)
#### 背景 #35602 提供Python侧开发子图替换类Pass的方式: - 利用Paddle Python API或者辅助类型定义子图program用来匹配/替换图; - Python侧注册Pass时,将注册函数最终转换为protobuf定义的PassDesc数据形式,供C++侧进行解析完成Pass实例注册。 本PR即为根据PassDesc规则描述解析生成Pass实例。 #### 方案设计 ##### Pass规则验证 在以往的Pass开发中,会存在随着算子迭代引发的匹配失效或者错误匹配的问题,该问题可以通过扫描算子支持的参数设置及参数类型等来判断是否应该使用该Pass或者给出提示需要修改Pass代码。 当前Pass开发中提供了算子兼容性OpCompatSensiblePass用于解决上述问题。但同时还存在不足:由于以往Pass开发在运行时才能获取到pattern信息,所以需要在执行Pass时才可以判断。 使用PassDesc表示的Pass可以在执行Pass前验证上述问题,这个过程在VerifyDesc中完成。 ##### 根据匹配子图构造pattern GeneratePass对于图匹配和替换使用GraphPatternDecetor完成,构造匹配pattern实际上就是将对应对象成员PDPattern中添加PDNode和边关系。该过程在函数`InitGeneratePattern`中完成,该函数没有作为GeneratePass的成员方法,主要出于后续可能开发新的Decetor考虑,GeneratePass与Decetor的操作是没有关联的。 初始化pattern主要通过遍历匹配子图program的全部算子实现: 1. 添加当前算子对应PDNode及限制条件(算子类型、属性限制等); 2. 遍历当前算子对应输入并从pattern中尝试获取PDNode: - 在pattern中获取到PDNode且为输出节点:表示属于匹配子图的中间节点,将该PDNode设置为中间节点; - 在pattern中没有获取到PDNode:添加该输入PDNode并设置作为输入节点; - 设置输入到算子的边关系; 3. 遍历当前算子对应输出: - 在pattern中获取到PDNode且为输入节点:表示属于匹配子图的中间节点,将该PDNode设置为中间节点; - 在pattern中没有获取到PDNode:添加该输入PDNode并设置作为输出节点; - 设置算子到输出的边关系; ##### 根据替换子图操作graph 替换子图操作的过程在`GetGenerateRewrite`函数中完成,与`InitGeneratePattern`类似没有作为GeneratePass的成员方法。 生成替换子图操作过程如下: 1. 判断冗余替换子图; 2. 遍历替换子图program的全部算子添加替换子图Node: 1. 添加当前算子的Node及属性设置; 2. 遍历当前算子对应输入,添加中间variable节点; 3. 遍历当前算子对应输出,添加中间variable节点; 4. 添加输入/输出节点与算子节点的边关系; 3. 删除匹配图中属于中间节点的Node; ##### 优化子图验证 对于替换子图或者替换后的计算图是否可以正确运行等,可以在执行Pass时验证,从而防止在后续执行计算图时出现异常。 当前Pass执行直接修改计算图,验证失败时无法很好的完成还原操作,目前子图验证暂时默认成功,留到后续改进。
- Loading branch information
Showing
7 changed files
with
674 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/ir/generate_pass.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { | ||
const proto::BlockDesc& block = pass_desc.pattern().blocks(0); | ||
// Traverse all operators to create subgraph. | ||
for (int index = 0; index < block.ops_size(); ++index) { | ||
const proto::OpDesc& op = block.ops(index); | ||
// Create a PDNode for current operator. Use the index as name to avoid | ||
// multiple operators with same type. Get a PDNode from pattern subgraph | ||
// through index in rewrite phase. | ||
PDNode* op_pdnode = | ||
pattern->NewNode(std::to_string(index))->assert_is_op(op.type()); | ||
// Create PDNodes for inputs of current operator. | ||
for (const proto::OpDesc::Var& var : op.inputs()) { | ||
for (const std::string& argument : var.arguments()) { | ||
// The input may be the output of other operator. | ||
PDNode* var_pdnode = pattern->RetrieveNode(argument); | ||
if (nullptr == var_pdnode) { | ||
var_pdnode = pattern->NewNode(argument)->AsInput(); | ||
} else if (var_pdnode->IsOutput()) { | ||
var_pdnode->AsIntermediate(); | ||
} | ||
var_pdnode->assert_is_op_input(op.type()); | ||
pattern->AddEdge(var_pdnode, op_pdnode); | ||
} | ||
} | ||
// Create PDNodes for outputs of current operator. | ||
for (const proto::OpDesc::Var& var : op.outputs()) { | ||
for (const std::string& argument : var.arguments()) { | ||
// The output may be the input of other operator. | ||
PDNode* var_pdnode = pattern->RetrieveNode(argument); | ||
if (nullptr == var_pdnode) { | ||
var_pdnode = pattern->NewNode(argument)->AsOutput(); | ||
} else if (var_pdnode->IsInput()) { | ||
var_pdnode->AsIntermediate(); | ||
} | ||
var_pdnode->assert_is_op_output(op.type()); | ||
pattern->AddEdge(op_pdnode, var_pdnode); | ||
} | ||
} | ||
// Set attribute condition for current operator. | ||
for (const proto::OpDesc::Attr& attr : op.attrs()) { | ||
op_pdnode->assert_more([&](Node* x) { | ||
if (x && x->IsOp()) { | ||
OpDesc* op_desc = x->Op(); | ||
if (op_desc->HasAttr(attr.name())) { | ||
return GetAttrValue(attr) == op_desc->GetAttr(attr.name()); | ||
} | ||
return false; | ||
} | ||
return false; | ||
}); | ||
} | ||
} | ||
} | ||
|
||
GraphPatternDetector::handle_t GetGenerateRewrite( | ||
const PDPattern& pattern, const proto::PassDesc& pass_desc) { | ||
GraphPatternDetector::handle_t handler = [&]( | ||
const GraphPatternDetector::subgraph_t subgraph, Graph* graph) { | ||
// There are some duplicate patterns. | ||
for (auto iter : subgraph) { | ||
if (nullptr == graph->RetrieveNode(iter.second->id())) { | ||
VLOG(3) << "Node [" << iter.second->Name() | ||
<< "] of subgraph has been removed. So skip this optimize."; | ||
return; | ||
} | ||
} | ||
const proto::BlockDesc& block = pass_desc.replace().blocks(0); | ||
// `var_node_maps` record the mapping of variable to the pattern subgraph. | ||
std::map<std::string, Node*> var_node_maps; | ||
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { | ||
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var())); | ||
var_node_maps.insert({var_map.replace_var(), node}); | ||
} | ||
// Traverse all operators to create subgraph. | ||
for (const proto::OpDesc& op : block.ops()) { | ||
OpDesc op_desc; | ||
std::vector<Node *> in_nodes, out_nodes; | ||
op_desc.SetType(op.type()); | ||
// Create Nodes for inputs of current operator. | ||
for (const proto::OpDesc::Var& var : op.inputs()) { | ||
std::vector<std::string> arguments; | ||
for (const std::string& argument : var.arguments()) { | ||
// The input may be mapped on the operator of pattern subgraph. | ||
Node* node = nullptr; | ||
auto iter = var_node_maps.find(argument); | ||
if (var_node_maps.end() == iter) { | ||
VarDesc var_desc(patterns::UniqueKey(argument)); | ||
node = graph->CreateVarNode(&var_desc); | ||
var_node_maps.insert({argument, node}); | ||
} else { | ||
node = iter->second; | ||
} | ||
in_nodes.push_back(node); | ||
arguments.push_back(node->Name()); | ||
} | ||
op_desc.SetInput(var.parameter(), arguments); | ||
} | ||
// Create Nodes for outputs of current operator. | ||
for (const proto::OpDesc::Var& var : op.outputs()) { | ||
std::vector<std::string> arguments; | ||
for (const std::string& argument : var.arguments()) { | ||
// The output may be mapped on the operator of pattern subgraph. | ||
Node* node = nullptr; | ||
auto iter = var_node_maps.find(argument); | ||
if (var_node_maps.end() == iter) { | ||
VarDesc var_desc(patterns::UniqueKey(argument)); | ||
node = graph->CreateVarNode(&var_desc); | ||
var_node_maps.insert({argument, node}); | ||
} else { | ||
node = iter->second; | ||
} | ||
out_nodes.push_back(node); | ||
arguments.push_back(node->Name()); | ||
} | ||
op_desc.SetOutput(var.parameter(), arguments); | ||
} | ||
// Set attribute for current operator. | ||
for (const proto::OpDesc::Attr& attr : op.attrs()) { | ||
op_desc.SetAttr(attr.name(), GetAttrValue(attr)); | ||
} | ||
// Create a Node for current operator. | ||
Node* op_node = graph->CreateOpNode(&op_desc); | ||
for (Node* node : in_nodes) { | ||
IR_NODE_LINK_TO(node, op_node); | ||
} | ||
for (Node* node : out_nodes) { | ||
IR_NODE_LINK_TO(op_node, node); | ||
} | ||
} | ||
// Remove nodes that are intermediate. | ||
std::unordered_set<const Node*> remove_nodes; | ||
for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) { | ||
remove_nodes.emplace(subgraph.at(pdnode.get())); | ||
} | ||
for (auto iter : var_node_maps) { | ||
remove_nodes.erase(iter.second); | ||
} | ||
GraphSafeRemoveNodes(graph, remove_nodes); | ||
}; | ||
return handler; | ||
} | ||
|
||
GeneratePass::GeneratePass(const std::string& binary_str) { | ||
multi_pass_desc_.ParseFromString(binary_str); | ||
VerifyDesc(); | ||
} | ||
|
||
GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) | ||
: multi_pass_desc_(multi_pass_desc) { | ||
VerifyDesc(); | ||
} | ||
|
||
void GeneratePass::ApplyImpl(Graph* graph) const { | ||
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { | ||
GraphPatternDetector detector; | ||
InitGeneratePattern(pass_desc, detector.mutable_pattern()); | ||
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); | ||
// The rewrited graph needs to be verified. Current Pass should be skipped | ||
// if validation failed. Rewrite based on the original graph cannot | ||
// implement rollback operation. | ||
VerifyGraph(*graph); | ||
} | ||
} | ||
|
||
void GeneratePass::VerifyDesc() const { | ||
PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0, | ||
platform::errors::InvalidArgument( | ||
"Size of PassDesc should not be empty.")); | ||
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { | ||
// Check inputs/outputs of subgraph should in `var_maps`. | ||
std::set<std::string> pattern_var_sets, replace_var_sets; | ||
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { | ||
pattern_var_sets.emplace(var_map.pattern_var()); | ||
replace_var_sets.emplace(var_map.replace_var()); | ||
} | ||
auto check_vars = [=](std::set<std::string>* var_sets, | ||
const proto::BlockDesc& block) { | ||
for (const proto::OpDesc& op : block.ops()) { | ||
for (const proto::OpDesc::Var& var : op.outputs()) { | ||
for (const std::string& argument : var.arguments()) { | ||
var_sets->emplace(argument); | ||
} | ||
} | ||
} | ||
for (const proto::OpDesc& op : block.ops()) { | ||
for (const proto::OpDesc::Var& var : op.inputs()) { | ||
for (const std::string& argument : var.arguments()) { | ||
PADDLE_ENFORCE_NE( | ||
var_sets->find(argument), var_sets->end(), | ||
platform::errors::InvalidArgument( | ||
"Subgraph of PassDesc has argument [%s] not in `var_maps`.", | ||
argument)); | ||
} | ||
} | ||
} | ||
}; | ||
check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0)); | ||
check_vars(&replace_var_sets, pass_desc.replace().blocks(0)); | ||
} | ||
} | ||
|
||
bool GeneratePass::VerifyGraph(const Graph& graph) { | ||
// Return true temporarily. | ||
return true; | ||
} | ||
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" | ||
#include "paddle/fluid/framework/ir/pass.h" | ||
#include "paddle/fluid/framework/pass_desc.pb.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
namespace ir { | ||
|
||
// Generate a substitute pass from protobuf. | ||
class GeneratePass : public Pass { | ||
public: | ||
// from binary_str | ||
explicit GeneratePass(const std::string& binary_str); | ||
// from PassDesc/MultiPassDesc | ||
explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc); | ||
|
||
protected: | ||
void ApplyImpl(Graph* graph) const override; | ||
|
||
private: | ||
GeneratePass() = delete; | ||
DISABLE_COPY_AND_ASSIGN(GeneratePass); | ||
// Verify desc | ||
void VerifyDesc() const; | ||
// Verify graph | ||
static bool VerifyGraph(const Graph& graph); | ||
|
||
proto::MultiPassDesc multi_pass_desc_; | ||
}; | ||
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle |
Oops, something went wrong.