Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Prim][NewIR] Support forward decomposition in new IR #55480

Merged
merged 7 commits into from
Aug 8, 2023

Conversation

cyber-pioneer
Copy link
Contributor

@cyber-pioneer cyber-pioneer commented Jul 17, 2023

PR types

New features

PR changes

Others

Description

Pcard-66975
To support forward decomposition in New IR.
The process is as follows:
image

Note:
This pr finishes base framework of forward decomposition. To finish this feature, about 3 todos case will be done:

  1. To process unused outputs after program is decomposed.
  2. To support recursive call.
  3. To support custom vjp.

@paddle-bot
Copy link

paddle-bot bot commented Jul 17, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@cyber-pioneer cyber-pioneer changed the title [No merge] test add infer var_type [No merge][Prim] Add forward prim in new ir Jul 24, 2023
map_output_for_composite,
prepare_python_api_arguments,
)

Copy link
Contributor

@cxxly cxxly Jul 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

follow pep8 codestyle
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

block,
blacklist=frozenset(),
whitelist=frozenset(),
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shoule be program

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

blacklist = prim_config["forward_blacklist"] | blacklist

with framework.program_guard(main_program):
_prim_logger.info("Decompose composite forward ops begin...")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use debug level for logging readonly info.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


_prim_logger = get_logger(
"prim", logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

follow python logging best pratices, use module-level logger https://docs.python.org/3/howto/logging.html
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done



def _decom_execute(block, op_filter):
"""The operators in block wich satisfy the filter conditon will be decomposed into primitives."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decom_execute is still unreadable,consider how to spilit into different function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@cyber-pioneer cyber-pioneer changed the title [No merge][Prim] Add forward prim in new ir [No merge][Prim][NewIR] Support forward decomposition in new IR Aug 1, 2023
@cyber-pioneer cyber-pioneer reopened this Aug 3, 2023
import typing

from paddle import ir
from paddle.fluid import core
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fluid 逻辑已经迁移,使用paddle.framework.core

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

from paddle.fluid import core
from paddle.fluid.libpaddle.ir import Block, Program

from .utils import get_decomp_rule
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

遵循pep8 import 规范, 非必要情况不直接import 函数/类,import模块,并通过xx.xx形式访问

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

logging.debug(f"Decompose composite forward ops finish: {replace_ops}")


def _decompose_subgraph(block, op_filter):
Copy link
Contributor

@cxxly cxxly Aug 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加注释解释下 op_filter 功能,或者使用type_hints注明 filter是函数类型

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


if lower:
change = True
core.prim_config["composite_ops_record"].add(op_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注意到这个信息只是给记录日志使用,是否做成无状态更好,避免使用全局变量

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decompose会在程序里调用多次,需要用使用全局变量记录所有已拆解算子的信息

for item in block:
_decompose_subgraph(item, op_filter)
return
raise TypeError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

异常需要包含关键信息

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# limitations under the License.

from .primitive_op import * # noqa: F403
from .utils import register_decomp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

遵循导入规范

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""define composite rule of op mean"""
x_shape = x.shape
axes = axis or list(range(0, len(x_shape)))
axes = [axes] if isinstance(axes, int) else axes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用tuple

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

def mean_decomp(x, axis, keepdim):
"""define composite rule of op mean"""
x_shape = x.shape
axes = axis or list(range(0, len(x_shape)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用tuple

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

def register(self, op_type, rule):
assert (
op_type not in self.rules
), f'name "{op_type}" should not be registered before.'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用异常检查入参

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@cxxly
Copy link
Contributor

cxxly commented Aug 7, 2023

python中,代码组织通过目录树路径,一般文件中不再需要包名信息,建议文件命名去掉包名前缀

def _decomp(*args):
return f(*args)

_decomposition_ops.register(op_type, _decomp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

移动到外部作用域

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')

def wrapper(f):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用functools.wraps,可以避免修改函数doc string

Copy link
Contributor Author

@cyber-pioneer cyber-pioneer Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改修饰器,functools.wraps此处用不上


@register_decomp('pd.mean')
def mean_decomp(x, axis, keepdim):
"""define composite rule of op mean"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议去掉decomp后缀

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@cyber-pioneer cyber-pioneer changed the title [No merge][Prim][NewIR] Support forward decomposition in new IR [Prim][NewIR] Support forward decomposition in new IR Aug 7, 2023
@cyber-pioneer
Copy link
Contributor Author

python中,代码组织通过目录树路径,一般文件中不再需要包名信息,建议文件命名去掉包名前缀

done

Copy link
Contributor

@risemeup1 risemeup1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@cyber-pioneer cyber-pioneer merged commit 523916f into PaddlePaddle:develop Aug 8, 2023
27 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants