Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.35、40】 Migrate paddle.nn.ChannelShuffle/ClipGradByNorm into pir #59718

Closed
wants to merge 9 commits into from

Conversation

fsczz
Copy link
Contributor

@fsczz fsczz commented Dec 5, 2023

PR types

Others

PR changes

APIs

Description

PIR API 推全升级
paddle.nn.ClipGradByNorm 引用了clip_by_norm,对clip_by_norm迁移升级至 pir,并更新单测,test_clip_by_norm_op单测覆盖率:4/4 test_gradient_clip 单测覆盖率:0/2
paddle.nn.ChannelShuffle 引用了paddle.nn.functional.channel_shuffle,对channel_shuffle迁移升级至 pir,并更新单测, 单测覆盖率:6/6(TestChannelShuffleError 通过)

test_gradient_clip报错:
image

Copy link

paddle-bot bot commented Dec 5, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Dec 5, 2023
@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Dec 6, 2023
@MarioLulab
Copy link
Contributor

#58067

@Aurelius84
Copy link
Contributor

需要pre-commit 处理下代码风格的问题

Copy link
Contributor

@MarioLulab MarioLulab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice work ~
但还有些地方没有适配:
ClipGradByNorm : 需要适配 class ClipGradByNorm,为其增添 _pir_clip 方法。具体可参考:ClipGradByGlobalNorm。适配了 class ClipGradByNorm 后,再辛苦迁移一下 test/legacy_test/test_gradient_clip.py 文件内的相关单测

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
也要打开 check_pir 开关

self.check_output_with_place(place, atol=0.001, check_pir=True)


def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out',check_pir=True)


class TestChannelLast(TestChannelShuffleOp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该单侧下的 test_static_graph_functional,test_static_graph_layer 需要用 test_with_pir_api 修饰。run_dygraph 不需要,因为其为动态图单测

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
test_static_graph_functional,test_static_graph_layer 打开后报错

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_static_graph_functional 和 test_static_graph_layer 的单测中所有的:
base.default_main_program() 替换为 paddle.static.default_main_program() 即可解决错误

@@ -162,7 +162,7 @@ def test_static_graph_layer(self):

np.testing.assert_allclose(res_1[0], out_1_np)
np.testing.assert_allclose(res_2[0], out_2_np)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test_with_pir_api

@@ -200,10 +200,10 @@ def run_dygraph(self, groups, data_format):
if data_format != 'NCHW':
channel_shuffle_str += f', data_format={data_format}'
self.assertEqual(channel_shuffle.extra_repr(), channel_shuffle_str)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test_with_pir_api

def test_dygraph1(self):
self.run_dygraph(3, "NCHW")

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test_with_pir_api

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该文件下的 TestChannelShuffleError 是否支持 pir 模式?若不支持麻烦在 pr 描述里说明,方便我们后续 fix

@fsczz fsczz closed this by deleting the head repository Dec 20, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants