-
Notifications
You must be signed in to change notification settings - Fork 756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dev Zeta op #10189
dev Zeta op #10189
Conversation
@@ -175,6 +175,8 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t | |||
BINARY_MATH_BACKWARD_OP_SEQ_2 \ | |||
BINARY_MATH_BACKWARD_OP_SEQ_3 | |||
|
|||
#define BINARY_MATH_FLOATING_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kZeta) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以考虑将这个宏定义放到前面,和 Forward 的放到一起。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以考虑将这个宏定义放到前面,和 Forward 的放到一起。
已修改
oneflow/core/ndarray/binary_func.h
Outdated
@@ -284,6 +284,11 @@ struct BinaryFuncINN final { | |||
// placeholder, no definition required, the type is only used to generate Op | |||
}; | |||
|
|||
template<typename T> | |||
struct BinaryFuncZeta final { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得暂时可以用 BinaryFuncZeta 来规避编译错误。可以在注释中说明一下,是因为 GetBinaryBroadcastSbpSignature 的模版参数需要,才定义这个结构。
感觉 binary_func.h 的大部分代码都已失去最初的作用。大部分代码的逻辑应该都被 binary_functor.h 及其 CUDA 版本替代了。目前 binary_func.h 的作用可能主要是 GetBinaryBroadcastSbpSignature 需要,以及最近几个月新增加的 BinaryFuncNanSum, BinaryFuncIEN, BinaryFuncINN。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加
namespace primitive { | ||
namespace broadcast_elementwise_binary { | ||
|
||
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这个.cu里并没有实际计算,定义的部分是不是放broadcast_elementwise_binary.cuh
里就行?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这个.cu里并没有实际计算,定义的部分是不是放
broadcast_elementwise_binary.cuh
里就行?
好的,这里就注册了下,我把它移到broadcast_elementwise_binary.cuh
里
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这个.cu里并没有实际计算,定义的部分是不是放
broadcast_elementwise_binary.cuh
里就行?好的,这里就注册了下,我把它移到
broadcast_elementwise_binary.cuh
里
已修改
zeta函数,torch对应文档
![image](https://user-images.githubusercontent.com/41790911/234006002-7c3984b9-c763-4fb5-b529-226717c37d7c.png)