[Hackathon] 61. Complete fp16/bf16 support for the segment_pool operator #53785
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
❌ The PR is not created using the PR template. You can refer to this Demo.
LGTM
@@ -153,10 +166,14 @@ def setUp(self):
        x, segment_ids = self.set_data()
        result, self.gradient = self.compute(x, segment_ids)
        self.inputs = {
-           'X': x.astype(self.dtype),
+           'X': x,
Since this inherits from TestSegmentOps, can this setUp part be removed? It is all duplicated code.
Fixed.
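The pattern the reviewer is pointing at can be sketched as follows: the subclass reuses the inherited setUp and overrides only a small hook, so the duplicated preparation code disappears. Class and method names below are illustrative, not Paddle's actual test code.

```python
# Illustrative sketch only: names do not match Paddle's real test suite.
class TestSegmentOps:
    def setUp(self):
        self.init_dtype()                  # hook that subclasses override
        self.inputs = {"X": [1.0, 2.0, 3.0]}

    def init_dtype(self):
        self.dtype = "float32"


class TestSegmentOpsFP16(TestSegmentOps):
    # No duplicated setUp: only the dtype hook differs from the base class.
    def init_dtype(self):
        self.dtype = "float16"
```

Calling `TestSegmentOpsFP16().setUp()` then runs the shared preparation once, with `self.dtype` set to `"float16"`.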
@@ -445,6 +445,57 @@ CUDA_ATOMIC_WRAPPER(Max, phi::dtype::float16) {
}
#endif

inline static __device__ uint32_t max_to_low_half_bf16(uint32_t val, float x) {
It would be best to keep the naming consistent with the other bf16 functions in this file, e.g. bf16_max_to_low_half.
Fixed.
  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Max, phi::dtype::bfloat16) {
Referring to L248, do we also need to write a definition for the __nv_bfloat16 type?
L248 uses atomicAdd; for bfloat16 there is no atomicMax or atomicMin function available to call.
Per https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html, atomicMax has no bfloat16 overload.
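Since atomicMax has no bfloat16 overload, the usual workaround (mirroring the float16 wrapper earlier in the file) is to operate on an aligned 32-bit word with an atomicCAS loop, updating only the half-word that holds the bf16 value. Below is a host-side Python emulation of just the half-word max helpers, as a sketch only; the actual device code uses CUDA intrinsics, and this sketch truncates to bf16 instead of rounding-to-nearest-even, so the rounding behavior may differ from the real kernel.

```python
import struct


def bf16_bits_to_float(bits16):
    """Widen 16 bf16 bits to float32 (bf16 is the high half of an fp32)."""
    return struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))[0]


def float_to_bf16_bits(f):
    """Narrow a float32 to 16 bf16 bits by truncation (sketch only)."""
    return struct.unpack("<I", struct.pack("<f", f))[0] >> 16


def bf16_max_to_low_half(val, x):
    """max(x, bf16 in the low half of val), written back into the low half."""
    low = bf16_bits_to_float(val & 0xFFFF)
    packed = float_to_bf16_bits(max(low, x))
    return (val & 0xFFFF0000) | packed


def bf16_max_to_high_half(val, x):
    """Same for the high half, mirroring the `(val & 0xFFFFu) | ...` return."""
    high = bf16_bits_to_float(val >> 16)
    packed = float_to_bf16_bits(max(high, x))
    return (val & 0x0000FFFF) | (packed << 16)
```

On the device, the wrapper would read the containing 32-bit word, apply the matching helper, and retry with atomicCAS until the word is unchanged between read and swap.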
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Description
Complete fp16/bf16 support for the segment_pool operator.
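For context on what the PR extends: segment_pool reduces the rows of X that share a segment id. A minimal NumPy sketch of segment max pooling, assuming sorted, non-negative segment ids (the function name is illustrative, not Paddle's API):

```python
import numpy as np


def segment_max(x, segment_ids):
    """Reduce rows of x sharing a segment id with max; ids must be sorted."""
    num_segments = int(segment_ids[-1]) + 1
    out = np.full((num_segments,) + x.shape[1:], -np.inf, dtype=x.dtype)
    for row, seg in zip(x, segment_ids):
        out[seg] = np.maximum(out[seg], row)
    return out


x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
seg = np.array([0, 0, 1])  # rows 0-1 form segment 0, row 2 forms segment 1
print(segment_max(x, seg))  # -> [[3. 4.], [5. 6.]]
```

When the CUDA kernel computes this reduction in parallel, multiple threads may update the same output row at once, which is why an fp16/bf16 atomicMax (the subject of the review above) is needed.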