Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon】61. segment_pool 算子 fp16/bf16 完善 #53785

Merged
merged 1 commit into from
May 18, 2023

Conversation

co63oc
Copy link
Contributor

@co63oc co63oc commented May 13, 2023

PR types

Others

PR changes

Others

Description

segment_pool 算子 fp16/bf16 完善
图片

@paddle-bot
Copy link

paddle-bot bot commented May 13, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels May 13, 2023
@paddle-bot
Copy link

paddle-bot bot commented May 13, 2023

❌ The PR is not created using PR's template. You can refer to this Demo.
Please use PR's template, it helps save our maintainers' time so that more developers get helped.

@ZHUI ZHUI requested a review from DesmonDay May 15, 2023 08:38
DesmonDay
DesmonDay previously approved these changes May 15, 2023
Copy link
Contributor

@DesmonDay DesmonDay left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -153,10 +166,14 @@ def setUp(self):
x, segment_ids = self.set_data()
result, self.gradient = self.compute(x, segment_ids)
self.inputs = {
'X': x.astype(self.dtype),
'X': x,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

继承了TestSegmentOps,setUp这部分是不是可以删掉了,都是重复代码

Copy link
Contributor Author

@co63oc co63oc May 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@@ -445,6 +445,57 @@ CUDA_ATOMIC_WRAPPER(Max, phi::dtype::float16) {
}
#endif

inline static __device__ uint32_t max_to_low_half_bf16(uint32_t val, float x) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名方式最好和文件内的其他bf16函数统一,如bf16_max_to_low_half

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Max, phi::dtype::bfloat16) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考L248,是不是也需要写一个__nv_bfloat16类型的定义

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

248使用的是atomicAdd,没有atomicMax,atomicMin函数可以调用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@ZzSean ZzSean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZzSean ZzSean merged commit 0bed220 into PaddlePaddle:develop May 18, 2023
@co63oc co63oc deleted the segment branch May 18, 2023 08:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants