【Hackathon 4 No.19】Add polygamma API to Paddle #53791
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
❌ The PR was not created using the PR template. You can refer to this Demo.
https://github.com/PaddlePaddle/Paddle/blob/91a77fe079ad334e75a122168a8338fa3acd0a97/python/paddle/tensor/math.py#L5684-L5686 The test results are as follows. Can this be considered a bug, according to the definition?
@PommesPeter Hi, there is indeed a behavioral difference here: starting from version 1.8, torch changed the output at x=0 from nan to -inf. I'd like to confirm how much impact this behavioral difference has on the polygamma implementation. Apart from the output at x=0, are there any other differences?
@PommesPeter Also, judging from the current implementation, a new OP has been added, which differs from the RFC plan of composing existing APIs. Out of consideration for Paddle's multi-hardware adaptation and similar concerns, we hope to keep the number of basic operators as small as possible. Was the new operator added because of a problem the RFC plan could not solve, or are there other special considerations?
So far only the result at x=0 is problematic; there are no other differences.
The initial task requirement for this operator was a C++ implementation. After studying the torch implementation, I found that using the zeta function reduces the computation that recursion would incur. The RFC document originally planned a recursive approach, but that could have a significant performance impact, so for kernel performance I followed the torch implementation instead.
Please also update the RFC document accordingly: submit a PR revising the description of the OP's forward and backward computation logic.
OK, updated: PaddlePaddle/community#542
The RFC document has been updated. Please review, @zoooo0820
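The zeta-based approach mentioned above can be sketched as follows. This is a minimal pure-Python illustration, not Paddle's actual C++ kernel; the `hurwitz_zeta` helper and its simple Euler-Maclaurin tail correction are my own simplification of the idea.

```python
import math

def hurwitz_zeta(s, x, terms=1000):
    # Truncated series for zeta(s, x) = sum_{k>=0} (x + k)^(-s),
    # plus a simple Euler-Maclaurin correction for the dropped tail.
    head = sum((x + k) ** -s for k in range(terms))
    n = x + terms
    tail = n ** (1 - s) / (s - 1) + 0.5 * n ** -s
    return head + tail

def polygamma(n, x):
    # polygamma(n, x) = (-1)^(n+1) * n! * zeta(n+1, x) for n >= 1,
    # which avoids the recursion a recurrence-based implementation
    # would need.
    return (-1) ** (n + 1) * math.factorial(n) * hurwitz_zeta(n + 1, x)
```

For example, `polygamma(1, 1.0)` should be close to the known closed form `pi**2 / 6`.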
@PommesPeter Hi, regarding
So is the current solution to keep the value at that point as nan and exclude the zero case? Because this implementation is inconsistent with current scipy and pytorch. At present the only CI error is at that point. From the investigation,
Keeping the status quo for now is fine,
Right, we can directly reuse digamma; I'll refer to the digamma unit tests to make CI pass.
Sure.
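A minimal sketch of the kind of tolerance-based check such a unit test could use, skipping the x=0 point where behavior intentionally differs. The reference values are well-known closed forms of the trigamma function, not values taken from Paddle's actual test suite.

```python
import math

# Known closed-form values of polygamma(n, x), deliberately avoiding
# x == 0, where Paddle intentionally keeps NaN.
REFERENCE_CASES = [
    (1, 1.0, math.pi ** 2 / 6),   # trigamma(1)   = pi^2 / 6
    (1, 0.5, math.pi ** 2 / 2),   # trigamma(1/2) = pi^2 / 2
]

def check_against_reference(polygamma_impl, rtol=1e-5):
    # Compare an implementation against the closed-form references
    # within a relative tolerance, as a digamma-style unit test would.
    for n, x, expected in REFERENCE_CASES:
        got = polygamma_impl(n, x)
        assert math.isclose(got, expected, rel_tol=rtol), (n, x, got, expected)
```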
CI-Approval has a report about detected use of std::cout / print. Please double-check whether it is a false positive.
} // namespace phi

PD_REGISTER_KERNEL(
Could more dtypes be extended here, such as fp16? When registering dtypes for a kernel, try to include every dtype that should theoretically be supported and that Paddle's framework mechanisms also support; otherwise problems will arise in certain scenarios.
In the RFC design stage, only fp32/fp64 were supported mainly because the earlier plan built on digamma, and the digamma kernel's dtype support is limited. This will be gradually extended later through other dedicated tasks.
return _C_ops.polygamma(x, n)
else:
    check_variable_and_dtype(
        x, "x", ["float32", "float64"], "polygamma"
After the kernel's data types are extended, the data-type support here can be relaxed accordingly.
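A hypothetical sketch of what the relaxed check could look like once low-precision kernels are registered. The dtype names for fp16/bf16 are assumptions rather than confirmed Paddle identifiers, and the helper below is a stand-in for Paddle's `check_variable_and_dtype`, not its real signature.

```python
# Stand-in for Paddle's check_variable_and_dtype, for illustration only.
def check_dtype(dtype, supported, api_name):
    # Reject dtypes the kernel is not registered for.
    if dtype not in supported:
        raise TypeError(
            f"{api_name} does not support dtype {dtype}; "
            f"supported dtypes are {supported}"
        )

# After the kernel registers fp16/bf16, the allowed list can be widened
# from ["float32", "float64"] to include the low-precision types.
SUPPORTED_POLYGAMMA_DTYPES = ["float16", "bfloat16", "float32", "float64"]

def validate_polygamma_input(dtype):
    check_dtype(dtype, SUPPORTED_POLYGAMMA_DTYPES, "polygamma")
```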
Yes, print was used in the code examples. @tianshuo78520a is strengthening the rule; once this PR is fully ready, an exemption can be granted.
OK.
Please fix the CI issues; you can merge develop.
OK, fixing now.
LGTM
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
LGTM for docs
LGTM for docs
LGTM
PR types
New features
PR changes
APIs
Description
rfc doc here: PaddlePaddle/community#472
updated rfc doc here: PaddlePaddle/community#542
polygamma doc here: PaddlePaddle/docs#5913