Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python support register pass via PassDesc #35602

Merged
merged 5 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ if(WITH_PYTHON)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
py_proto_compile(distributed_strategy_py_proto SRCS distributed_strategy.proto)
py_proto_compile(pass_desc_py_proto SRCS pass_desc.proto)
#Generate an empty \
#__init__.py to make framework_py_proto as a valid python module.
add_custom_target(fleet_proto_init ALL
Expand Down
39 changes: 39 additions & 0 deletions paddle/fluid/framework/pass_desc.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

syntax = "proto2";

import "framework.proto";
package paddle.framework.proto;

// Describes one subsitute subgraph.
message PassDesc {
message VarMap {
required string pattern_var = 1;
required string replace_var = 2;
}
message AttrMap {
required int32 pattern_op_idx = 1;
required int32 replace_op_idx = 2;
required string pattern_name = 3;
required string replace_name = 4;
}
required ProgramDesc pattern = 1;
required ProgramDesc replace = 2;
repeated VarMap var_maps = 3;
repeated AttrMap attr_maps = 4;
zhhsplendid marked this conversation as resolved.
Show resolved Hide resolved
}

// A series of PassDesc.
message MultiPassDesc {
optional string pass_type = 1;
zhhsplendid marked this conversation as resolved.
Show resolved Hide resolved
repeated PassDesc pass_descs = 2;
}
246 changes: 243 additions & 3 deletions python/paddle/fluid/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy
from . import core
from .framework import _apply_pass
import inspect
from os import path
import paddle
from . import core, unique_name
from .framework import _apply_pass, OpProtoHolder

try:
zhhsplendid marked this conversation as resolved.
Show resolved Hide resolved
from .proto import pass_desc_pb2
except ModuleNotFoundError:
import sys
sys.path.append(path.join(path.dirname(__file__), 'proto'))
from .proto import pass_desc_pb2


def get_data_vars(program):
Expand Down Expand Up @@ -115,3 +124,234 @@ def apply_pass(name):
build_strategy.enable_inplace = False
build_strategy._clear_finalized()
return build_strategy


class RegisterPassHelper(object):
def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
self._pass_type = pass_type
self._pass_pairs = pass_pairs
if isinstance(input_specs, dict):
self._input_specs = input_specs

def _get_args_from_func(self, func):
args = list()
arg_specs = inspect.getfullargspec(func)
for arg_name in arg_specs.args:
input_spec = self._input_specs.get(arg_name)
if isinstance(input_spec, paddle.static.InputSpec):
args.append(
paddle.static.data(arg_name, input_spec.shape,
input_spec.dtype))
elif isinstance(input_spec, paddle.ParamAttr):
args.append(paddle.ParamAttr(arg_name))
else:
args.append(paddle.static.data(arg_name, [-1]))
return args

def _func_to_program_desc(self, func, program_desc, is_replace=False):
vars = list()
program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(program, startup_program):
args = self._get_args_from_func(func)
for arg in args:
vars.append(arg.name)
outs = func(*args)
if not isinstance(outs, (list, tuple)):
outs = [outs]
for out in outs:
if isinstance(out, PassDesc.OpHelper):
for out in out.Outputs().values():
vars.extend(out)
elif isinstance(out, paddle.fluid.framework.Variable):
vars.append(out.name)
program_desc.ParseFromString(program.desc.serialize_to_string())
if is_replace:
attrs = list()
for op in program.current_block().ops:
if not isinstance(op, PassDesc.OpHelper):
continue
attrs.extend(op._attrs.values())
return vars, attrs
return vars

def SerializeMultiPassDesc(self):
switch_static_mode = paddle.in_dynamic_mode()
if switch_static_mode:
paddle.enable_static()
multi_pass_desc = pass_desc_pb2.MultiPassDesc()
multi_pass_desc.pass_type = self._pass_type
for (pattern, replace) in self._pass_pairs:
pass_desc = multi_pass_desc.pass_descs.add()
pattern_vars = self._func_to_program_desc(pattern,
pass_desc.pattern)
replace_vars, attrs = self._func_to_program_desc(
replace, pass_desc.replace, is_replace=True)
for (pattern_var, replace_var) in zip(pattern_vars, replace_vars):
var_map = pass_desc.var_maps.add()
var_map.pattern_var = pattern_var
var_map.replace_var = replace_var
pattern_op_idxs = dict()
for (idx, op) in enumerate(pass_desc.pattern.blocks[0].ops):
op_idxs = pattern_op_idxs.get(op.type)
if op_idxs:
op_idxs.append(idx)
else:
pattern_op_idxs[op.type] = [idx]
for attr in attrs:
attr_map = pass_desc.attr_maps.add()
attr_map.pattern_op_idx = pattern_op_idxs[
attr._pattern_op_type][attr._pattern_op_idx]
attr_map.replace_op_idx = attr._replace_op_idx
attr_map.pattern_name = attr._pattern_name
attr_map.replace_name = attr._replace_name
if switch_static_mode:
paddle.disable_static()
return multi_pass_desc.SerializeToString()


class PassDesc(object):
class AttrHelper(object):
def __init__(self, name, replace_op_idx):
self._pattern_op_type = None
self._pattern_op_idx = -1
self._replace_op_idx = replace_op_idx
self._pattern_name = name
self._replace_name = name

def ReusePattern(self, op, index=0, name=None):
if name:
self._pattern_name = name
self._pattern_op_type = op
self._pattern_op_idx = index

class OpHelper(object):
def __init__(self, type=None):
self._type = type

def __getattr__(self, name):
if self._type is not None:
raise AttributeError(
"type object 'OpHelper' has no attribute '{}'".format(name))
op = PassDesc.OpHelper(name)
op.Init()
return op

def __call__(self, *args, **kwargs):
for (in_name, in_args) in kwargs.items():
in_arg_names = list()
if isinstance(in_args, (list, tuple)):
if len(in_args) == 0:
raise ValueError(
"Input '{}' of operator '{}' cannot be empty.".
format(in_name, self._type))
else:
in_args = [in_args]
for in_arg in in_args:
if isinstance(in_arg, PassDesc.OpHelper):
in_arg_names.extend(in_arg.Output())
else:
in_arg_names.append(in_arg.name)
self._op_desc.set_input(in_name, in_arg_names)
return self

def Init(self):
block = paddle.static.default_main_program().current_block()
self._attrs = dict()
self._op_idx = len(block.ops)
self._op_desc = block.desc.append_op()
self._op_desc.set_type(self._type)
self._op_proto = OpProtoHolder.instance().get_op_proto(self._type)
block.ops.append(self)

def Attr(self, name):
attr = self._attrs.get(name)
if attr:
return attr
attr = PassDesc.AttrHelper(name, self._op_idx)
self._attrs[name] = attr
return attr

def SetAttr(self, name, value):
self._op_desc._set_attr(name, value)

def Output(self, name=None):
if name:
return self.Outputs()[name]
return list(self.Outputs().values())[0]

def Outputs(self):
outputs = self._op_desc.outputs()
if len(outputs) > 0:
return outputs
block = paddle.static.default_main_program().current_block()
for output_proto in self._op_proto.outputs:
name = unique_name.generate(self._type)
block.create_var(name=name)
self._op_desc.set_output(output_proto.name, [name])
return self._op_desc.outputs()

OP = OpHelper()


def RegisterPass(function=None, input_specs=None):
"""
The function decorator of Register Pass. Decorator @RegisterPass handles
the function and register it into a core.Pass instance. Use name of function
as Pass type.

Args:
function (callable): The function with return of callable pair(s) that
represents the pattern subgraph and the replace subgraph.
input_specs (dict[str, InputSpec]|None): Dict of InputSpec to specific the shape/dtype
information of Tensor. Some operators limit the shape and dtype of datas when
create subgraph with Paddle APIs. So user need specify InputSpec of data to
ensure create a correctly subgraph. Of course, this argument is not limited to
matching subgraph. The default is None.

Returns:
callables: Callable pair(s).

Examples:
.. code-block:: python

import paddle
from paddle.fluid.ir import RegisterPass

@RegisterPass
def multi_add_to_addn():
def pattern(x, y, z):
return paddle.add(paddle.add(x, y), z)
def replace(x, y, z):
return paddle.add_n([x, y, z])
return pattern, replace
"""

def _is_pass_pair(check_pair):
if isinstance(check_pair, (list, tuple)):
if len(check_pair) == 2:
if all(map(inspect.isfunction, check_pair)):
return True
return False

def decorated(python_func):
pass_type = python_func.__name__
signature = inspect.signature(python_func)
if len(signature.parameters) > 0:
raise NotImplementedError(
"Pass function with parameter is not supported now.")
elif len(signature.parameters) == 0:
pass_pairs = python_func()
if _is_pass_pair(pass_pairs):
pass_pairs = [pass_pairs]
elif not all(map(_is_pass_pair, pass_pairs)):
zhhsplendid marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Return value of Pass function must be (callable, callable)."
)
helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
return python_func

if inspect.isfunction(function):
return decorated(function)

return decorated