-
Notifications
You must be signed in to change notification settings - Fork 34
【Hackathon 9th No.127】feat: Implement torch._C._nn.linear to torch.nn.functional.linear conversion #353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for your contribution! |
b4c0103 to
2116c66
Compare
|
修改 log2json.py 是为了确保 JSON 报告包含完整且正确的状态和性能数据 |
| if isinstance(node.func, ast.Attribute): | ||
| # node.func.attr should be "linear" | ||
| if node.func.attr == "linear": | ||
| # node.func.value should be torch._C._nn | ||
| if isinstance(node.func.value, ast.Attribute): | ||
| # node.func.value.attr should be "_nn" | ||
| if node.func.value.attr == "_nn": | ||
| # node.func.value.value should be torch._C | ||
| if isinstance(node.func.value.value, ast.Attribute): | ||
| # node.func.value.value.attr should be "_C" | ||
| if node.func.value.value.attr == "_C": | ||
| # node.func.value.value.value should be torch | ||
| if ( | ||
| isinstance( | ||
| node.func.value.value.value, | ||
| ast.Name, | ||
| ) | ||
| and node.func.value.value.value.id | ||
| == "torch" | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按如下方法改造:
filtered_nodes = [
node
for node in [node]
if isinstance(node.func, ast.Attribute)
if node.func.attr == "linear"
if isinstance(node.func.value, ast.Attribute)
if node.func.value.attr == "_nn"
if isinstance(node.func.value.value, ast.Attribute)
if node.func.value.value.attr == "_C"
if isinstance(node.func.value.value.value, ast.Name)
if node.func.value.value.value.id == "torch"
]
if len(filtered_nodes) > 0:
...| gm.forward = types.MethodType(forward_func, gm) | ||
|
|
||
| # Update _code attribute so that gm.code returns the modified code | ||
| gm._code = new_code |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这么做是很不正式的做法。我们单独准备了 fx_graph_serialize 机制。具体你看#352 pr 里的https://github.com/PaddlePaddle/GraphNet/pull/352/files#diff-cdf8f6acc9f3a2d0573129027aaee53e37576cdcd576205d8e0b63c740003105 这里
| gm.forward = types.MethodType(forward_func, gm) | ||
|
|
||
| # Update _code attribute so that gm.code returns the modified code | ||
| gm._code = new_code |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这种直接改gm._code的做法是很不地道的,应当避免。
| # replace this line with modification code for task 123 (torch._C._nn.pad) | ||
| # replace this line with modification code for task 125 (torch._C._nn.gelu) | ||
| # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention) | ||
| # replace this line with modification code for task 127 (torch._C._nn.linear) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把上面的改动放到这里,减少其他 PR 的合入冲突
| # Use serialized code to check for unstable APIs | ||
| graph_text = serialize_graph_module_to_str(gm) | ||
| # Use code to check for unstable APIs | ||
| graph_text = gm.code |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分的代码不能改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check_unstable_to_stable 是裁判,裁判的逻辑不可以改。上面_impl_unstable_to_stable_linear_to_functional_linear是运动员。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
| # Use serialize_graph_module_to_str to get the serialized code | ||
| # This ensures the code is properly serialized with unstable API replacements | ||
| serialized_code = serialize_graph_module_to_str(gm) | ||
| gm._code = serialized_code |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这代码能去掉吗?
|
|
||
| # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention) | ||
|
|
||
| # replace this line with modification code for task 127 (torch._C._nn.linear) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把_impl_unstable_to_stable_linear_to_functional_linear函数的实现挪到这里,避开与其他 pr 的冲突
d49878a to
1a77a0e
Compare
- Implement direct node.target modification for API conversion - Use serialize_graph_module_to_str for API check in check_unstable_api - Add AST-based replacement function (commented) in fx_graph_serialize_util.py - Fix log2json.py to properly initialize result field and map speedup data - Simplify conversion logic by removing complex AST code - Tested with 50 samples: 100% success rate, ES(-6) = 1.013
1a77a0e to
3a3b149
Compare


PR Category
Description