-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Prim][NewIR] Support forward decomposition in new IR #55480
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
python/paddle/primitive/lowering.py
Outdated
map_output_for_composite, | ||
prepare_python_api_arguments, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/primitive/lowering.py
Outdated
block, | ||
blacklist=frozenset(), | ||
whitelist=frozenset(), | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shoule be program
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/primitive/lowering.py
Outdated
blacklist = prim_config["forward_blacklist"] | blacklist | ||
|
||
with framework.program_guard(main_program): | ||
_prim_logger.info("Decompose composite forward ops begin...") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use debug level for logging readonly info.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/primitive/lowering.py
Outdated
|
||
_prim_logger = get_logger( | ||
"prim", logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
follow python logging best pratices, use module-level logger https://docs.python.org/3/howto/logging.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/primitive/lowering.py
Outdated
|
||
|
||
def _decom_execute(block, op_filter): | ||
"""The operators in block wich satisfy the filter conditon will be decomposed into primitives.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
decom_execute
is still unreadable,consider how to spilit into different function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
b745302
to
ac20305
Compare
8535195
to
1c9c228
Compare
10cb89a
to
5ea6561
Compare
import typing | ||
|
||
from paddle import ir | ||
from paddle.fluid import core |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fluid 逻辑已经迁移,使用paddle.framework.core
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
from paddle.fluid import core | ||
from paddle.fluid.libpaddle.ir import Block, Program | ||
|
||
from .utils import get_decomp_rule |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
遵循pep8 import 规范, 非必要情况不直接import 函数/类,import模块,并通过xx.xx形式访问
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
logging.debug(f"Decompose composite forward ops finish: {replace_ops}") | ||
|
||
|
||
def _decompose_subgraph(block, op_filter): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加注释解释下 op_filter 功能,或者使用type_hints注明 filter是函数类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
if lower: | ||
change = True | ||
core.prim_config["composite_ops_record"].add(op_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注意到这个信息只是给记录日志使用,是否做成无状态更好,避免使用全局变量
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
decompose会在程序里调用多次,需要用使用全局变量记录所有已拆解算子的信息
for item in block: | ||
_decompose_subgraph(item, op_filter) | ||
return | ||
raise TypeError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
异常需要包含关键信息
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
# limitations under the License. | ||
|
||
from .primitive_op import * # noqa: F403 | ||
from .utils import register_decomp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
遵循导入规范
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
"""define composite rule of op mean""" | ||
x_shape = x.shape | ||
axes = axis or list(range(0, len(x_shape))) | ||
axes = [axes] if isinstance(axes, int) else axes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
def mean_decomp(x, axis, keepdim): | ||
"""define composite rule of op mean""" | ||
x_shape = x.shape | ||
axes = axis or list(range(0, len(x_shape))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/primitive/utils.py
Outdated
def register(self, op_type, rule): | ||
assert ( | ||
op_type not in self.rules | ||
), f'name "{op_type}" should not be registered before.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用异常检查入参
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python中,代码组织通过目录树路径,一般文件中不再需要包名信息,建议文件命名去掉包名前缀 |
python/paddle/primitive/utils.py
Outdated
def _decomp(*args): | ||
return f(*args) | ||
|
||
_decomposition_ops.register(op_type, _decomp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
移动到外部作用域
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/primitive/utils.py
Outdated
if not isinstance(op_type, str): | ||
raise TypeError(f'op_type must be str, but got {type(op_type)}.') | ||
|
||
def wrapper(f): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用functools.wraps,可以避免修改函数doc string
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改修饰器,functools.wraps此处用不上
|
||
@register_decomp('pd.mean') | ||
def mean_decomp(x, axis, keepdim): | ||
"""define composite rule of op mean""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议去掉decomp后缀
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
487510e
to
86ee55e
Compare
done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Others
Description
Pcard-66975
![image](https://private-user-images.githubusercontent.com/116002591/257503718-a5cf4ef1-5b47-44b7-97e7-ba840f8f1ae4.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjE3MzA4NzIsIm5iZiI6MTcyMTczMDU3MiwicGF0aCI6Ii8xMTYwMDI1OTEvMjU3NTAzNzE4LWE1Y2Y0ZWYxLTViNDctNDRiNy05N2U3LWJhODQwZjhmMWFlNC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjQwNzIzJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI0MDcyM1QxMDI5MzJaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT01NjcwZTk3MzliMjg3Nzg3ZTdjYTA4ZTA2NmY2MjVlMzgyYzJjYTQ0MGFiNDczZDkzMDM2ODFmYTg1YThjMTFhJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.u6k2KxDtcGARFJj8OxkSyxD0nxUfTLP5EmUUSWoseFw)
To support forward decomposition in New IR.
The process is as follows:
Note:
This pr finishes base framework of forward decomposition. To finish this feature, about 3 todos case will be done: