-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add complex support for allgather,diag,eye,gather,lookup_table_v2 #62764
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
还有,这些注册了复数kernel的应该在对应的python api的地方的类型校验以及英文文档处增加一个说明也支持complex
self.outputs = {'Out': np_out.reshape((2, 4, 64))} | ||
self.attrs = {'start_index': self.start_index} | ||
|
||
if core.is_compiled_with_xpu(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个xpu没必要吧,应该也没支持
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,我把这个删掉
@@ -26,5 +27,9 @@ | |||
|
|||
TestCEmbeddingOpFP32() | |||
|
|||
TestCEmbeddingOpComplex64() | |||
|
|||
TestCEmbeddingOpComplex64() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要跑两次啊
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里写错了,第二个应该是complex128
@@ -247,7 +247,7 @@ void GatherV2GradFunction(const phi::CPUContext& ctx, | |||
auto* out_data = ctx.Alloc<T>(out); | |||
auto out_dim = out->dims(); | |||
int64_t out_index_dim_size = out_dim[axis_index]; | |||
phi::funcs::set_constant(ctx, out, static_cast<T>(0.0)); | |||
phi::funcs::set_constant(ctx, out, static_cast<float>(0.0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么这里会特别指定float,是因为paddle定义的complex 和 c++ 对应的complex 不一致导致不能cast吗,如果是的可以用这种方式using MT = typename phi::dtype::MPTypeTrait::Type; 转成c++ 的type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set_constant的第三个参数是float类型的,在不注册complex的情况下static_cast(0.0)应该会隐士转化为float;注册complex后不能隐式转换,编译会出错,这里索性就直接改成了float。using MT = typename phi::dtype::MPTypeTrait::Type
这个方式下的phi::dtype::complex并没有转为c++type,映射后的应该还是paddle自己实现的复数类型,不是c++标准库里的complex。然后如果添加这些
template <typename T>
class MPTypeTrait<phi::dtype::complex<T>> {
public:
using Type = std::complex<T>;
};
把phi::dtype::complex映射成std::complex,目前来看会对MatrixReduceSumFunctor造成影响,里面使用的__shfl_down_sync洗牌函数不支持std::complex,导致编译错误。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那这里就直接用float吧,注释一下这个set_constant 函数,就只支持float类型value的输入
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Sorry to inform you that af60380's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
PR Category
Others
PR Types
New features
Description
add complex support for allgather,diag,eye,gather,lookup_table_v2