fix reduce min max backward bug #5651
Conversation
const auto& bcast_like =
    JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_like_op_, {output, input}, bcast_attrs));
const auto& bcast_eq =
    JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_equal_op_, {input, bcast_like}));
const auto& cast_like = JUST(OpInterpUtil::Dispatch<Tensor>(*cast_like_op_, {bcast_eq, input}));
const auto& reduce_sum_ =
    JUST(OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op_, {cast_like}, reduce_sum_attrs));
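The dispatch chain above implements the max/min backward rule: broadcast the reduced result back to the input shape, mark the input elements equal to it, count them, and later divide the upstream gradient by that count so tied extrema share the gradient. A minimal pure-Python sketch of the same math for a 1-D reduce_max (the function name and 1-D restriction are illustrative, not OneFlow API):

```python
def reduce_max_backward(x, dy):
    """Gradient of y = max(x) w.r.t. a 1-D list x, given upstream scalar dy."""
    m = max(x)                                   # forward output, broadcast back (broadcast_like)
    mask = [1.0 if v == m else 0.0 for v in x]   # broadcast_equal + cast_like
    count = sum(mask)                            # reduce_sum: how many elements tie for the max
    # broadcast_div: split dy evenly among all tied maxima
    return [dy * v / count for v in mask]
```

With ties, each tied element receives dy / count rather than the full dy.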
Please rewrite this with the functional interface:
const auto& reduce_sum_ = JUST(functional::ReduceSum(cast_like, ctx->axis, ctx->keepdims));
Make the same change in the other places too. CastLike does not have a functional interface yet, so please add one as well.
@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
op_expr_helper.h
This header file can be removed now.
};

Maybe<void> ReduceMaxOrMinOp::Init(const OpExpr& op) {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  const std::string& op_name = fw_op_expr->op_name();
  bcast_like_op_ =
      JUST(op_expr_helper::BroadcastLikeOp(/*axis=*/{-1}, GradientOpName(op_name + "_bcast_like")));
While you are at it, please also change the other op_expr_helper::XXXOp calls in this file to their functional equivalents.
const auto& bcast_eq = JUST(functional::BroadcastEqual(input, bcast_like));
const auto& cast_like = JUST(functional::CastLike(bcast_eq, input));
const auto& reduce_sum_ = JUST(functional::ReduceSum(cast_like, ctx->axis, ctx->keepdims));
const auto& bcast_div_ = JUST(functional::BroadcastDiv(dy, reduce_sum_));
bcast_div_ -> bcast_div
While referring to some implementations of the reduce_min backward pass, I found that OneFlow's behaviour differs from PyTorch's, so this PR fixes it.
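For reference, PyTorch's full-reduction min/max backward splits the upstream gradient evenly among tied extrema instead of giving each tied element the full gradient. A plain-Python sketch of that rule for reduce_min (illustrative name, 1-D only, not OneFlow or PyTorch API):

```python
def reduce_min_backward(x, dy):
    """PyTorch-style gradient of y = min(x) over a 1-D list x: ties share dy equally."""
    m = min(x)
    mask = [1.0 if v == m else 0.0 for v in x]
    count = sum(mask)  # number of elements tied for the minimum
    return [dy * v / count for v in mask]
```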