Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add complex support for allgather,diag,eye,gather,lookup_table_v2 #62764

Merged
merged 5 commits into from
Mar 29, 2024

Conversation

zbt78
Copy link
Contributor

@zbt78 zbt78 commented Mar 15, 2024

PR Category

Others

PR Types

New features

Description

add complex support for allgather,diag,eye,gather,lookup_table_v2

Copy link

paddle-bot bot commented Mar 15, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 15, 2024
@zbt78 zbt78 changed the title add complex support for allgather,diag,eye,gather add complex support for allgather,diag,eye,gather,lookup_table Mar 17, 2024
@zbt78 zbt78 changed the title add complex support for allgather,diag,eye,gather,lookup_table add complex support for allgather,diag,eye,gather,lookup_table_v2 Mar 17, 2024
Copy link
Contributor

@GGBond8488 GGBond8488 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还有,这些注册了复数kernel的应该在对应的python api的地方的类型校验以及英文文档处增加一个说明也支持complex

self.outputs = {'Out': np_out.reshape((2, 4, 64))}
self.attrs = {'start_index': self.start_index}

if core.is_compiled_with_xpu():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个xpu没必要吧,应该也没支持

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,我把这个删掉

@@ -26,5 +27,9 @@

TestCEmbeddingOpFP32()

TestCEmbeddingOpComplex64()

TestCEmbeddingOpComplex64()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么要跑两次啊

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里写错了,第二个应该是complex128

@@ -247,7 +247,7 @@ void GatherV2GradFunction(const phi::CPUContext& ctx,
auto* out_data = ctx.Alloc<T>(out);
auto out_dim = out->dims();
int64_t out_index_dim_size = out_dim[axis_index];
phi::funcs::set_constant(ctx, out, static_cast<T>(0.0));
phi::funcs::set_constant(ctx, out, static_cast<float>(0.0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么这里会特别指定float,是因为paddle定义的complex 和 c++ 对应的complex 不一致导致不能cast吗,如果是的可以用这种方式using MT = typename phi::dtype::MPTypeTrait::Type; 转成c++ 的type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_constant的第三个参数是float类型的,在不注册complex的情况下static_cast(0.0)应该会隐士转化为float;注册complex后不能隐式转换,编译会出错,这里索性就直接改成了float。using MT = typename phi::dtype::MPTypeTrait::Type这个方式下的phi::dtype::complex并没有转为c++type,映射后的应该还是paddle自己实现的复数类型,不是c++标准库里的complex。然后如果添加这些

template <typename T>
class MPTypeTrait<phi::dtype::complex<T>> {
 public:
  using Type = std::complex<T>;
};

把phi::dtype::complex映射成std::complex,目前来看会对MatrixReduceSumFunctor造成影响,里面使用的__shfl_down_sync洗牌函数不支持std::complex,导致编译错误。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那这里就直接用float吧,注释一下这个set_constant 函数,就只支持float类型value的输入

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

@GGBond8488 GGBond8488 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link

paddle-ci-bot bot commented Mar 27, 2024

Sorry to inform you that af60380's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@luotao1 luotao1 merged commit 0a2e7b6 into PaddlePaddle:develop Mar 29, 2024
30 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants