-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PIR API adaptor No.35、40】 Migrate paddle.nn.ChannelShuffle/ClipGradByNorm into pir #59718
Conversation
… my-cool-stuff
你的PR提交成功,感谢你对开源项目的贡献! |
需要pre-commit 处理下代码风格的问题 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice work ~
但还有些地方没有适配:
ClipGradByNorm : 需要适配 class ClipGradByNorm,为其增添 _pir_clip
方法。具体可参考:ClipGradByGlobalNorm。适配了 class ClipGradByNorm 后,再辛苦迁移一下 test/legacy_test/test_gradient_clip.py 文件内的相关单测
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
def test_check_grad(self): | ||
self.check_grad(['X'], 'Out') | ||
self.check_grad(['X'], 'Out',check_pir=True) | ||
|
||
|
||
class TestChannelLast(TestChannelShuffleOp): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
该单侧下的 test_static_graph_functional,test_static_graph_layer 需要用 test_with_pir_api 修饰。run_dygraph 不需要,因为其为动态图单测
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_static_graph_functional 和 test_static_graph_layer 的单测中所有的:
base.default_main_program()
替换为 paddle.static.default_main_program()
即可解决错误
@@ -162,7 +162,7 @@ def test_static_graph_layer(self): | |||
|
|||
np.testing.assert_allclose(res_1[0], out_1_np) | |||
np.testing.assert_allclose(res_2[0], out_2_np) | |||
|
|||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@test_with_pir_api |
@@ -200,10 +200,10 @@ def run_dygraph(self, groups, data_format): | |||
if data_format != 'NCHW': | |||
channel_shuffle_str += f', data_format={data_format}' | |||
self.assertEqual(channel_shuffle.extra_repr(), channel_shuffle_str) | |||
|
|||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@test_with_pir_api |
def test_dygraph1(self): | ||
self.run_dygraph(3, "NCHW") | ||
|
||
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@test_with_pir_api |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
该文件下的 TestChannelShuffleError 是否支持 pir 模式?若不支持麻烦在 pr 描述里说明,方便我们后续 fix
PR types
Others
PR changes
APIs
Description
PIR API 推全升级
paddle.nn.ClipGradByNorm
引用了clip_by_norm,对clip_by_norm迁移升级至 pir,并更新单测,test_clip_by_norm_op单测覆盖率:4/4 test_gradient_clip 单测覆盖率:0/2paddle.nn.ChannelShuffle
引用了paddle.nn.functional.channel_shuffle,对channel_shuffle迁移升级至 pir,并更新单测, 单测覆盖率:6/6(TestChannelShuffleError 通过)test_gradient_clip报错: