Fix BroadcastMatmulGrad bug (#6168)
* fix(BroadcastMatmulGrad): fix BroadcastMatmulGrad bug

* remove useless code

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
wyg1997 and oneflow-ci-bot committed Sep 6, 2021
1 parent 05d2b79 · commit 332c05c
Showing 5 changed files with 343 additions and 283 deletions.
39 changes: 38 additions & 1 deletion oneflow/core/autograd/gradient_funcs/matmul.cpp
@@ -102,9 +102,46 @@ Maybe<void> Matmul::Apply(const MatmulCaptureState* ctx, const TensorTuple& out_
   return Maybe<void>::Ok();
 }
 
+class BroadcastMatmul : public Matmul {
+ public:
+  Maybe<void> Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+};
+
+Maybe<void> BroadcastMatmul::Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads,
+                                   TensorTuple* in_grads) const {
+  if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe<void>::Ok(); }
+  CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+
+  in_grads->resize(2);
+  if (ctx->requires_grad_a) {
+    const auto& input_b = ctx->SavedTensors().at(ctx->b_index);
+    if (ctx->transpose_a) {
+      in_grads->at(0) =
+          JUST(functional::MatMul(input_b, out_grads.at(0), ctx->transpose_b, true, ctx->alpha));
+    } else {
+      in_grads->at(0) = JUST(
+          functional::MatMul(out_grads.at(0), input_b, false, !(ctx->transpose_b), ctx->alpha));
+    }
+  }
+
+  if (ctx->requires_grad_b) {
+    const auto& input_a = ctx->SavedTensors().at(ctx->a_index);
+    if (ctx->transpose_b) {
+      in_grads->at(1) =
+          JUST(functional::BroadcastMatmulGradB(out_grads.at(0), input_a, ctx->alpha));
+    } else {
+      in_grads->at(1) =
+          JUST(functional::BroadcastMatmulGradB(input_a, out_grads.at(0), ctx->alpha));
+    }
+  }
+
+  return Maybe<void>::Ok();
+}
+
 REGISTER_OP_EXPR_GRAD_FUNCTION("matmul", Matmul);
 REGISTER_OP_EXPR_GRAD_FUNCTION("batch_matmul", Matmul);
-REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_matmul", Matmul);
+REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_matmul", BroadcastMatmul);
 
 }  // namespace one
 }  // namespace oneflow
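The registration change is the heart of the fix: broadcast_matmul previously reused Matmul's backward, whose b-side gradient is a plain matmul and therefore keeps the broadcast batch dimensions. When a batched a is multiplied by a shared 2-D b, the gradient of b must instead be summed over those batch dimensions, which is what functional::BroadcastMatmulGradB provides. A sketch of the underlying math for the no-transpose case (batch index i; the transposed cases follow by swapping operands, as in the code above):

\[
Y_i = \alpha\, A_i B, \qquad
\frac{\partial L}{\partial A_i} = \alpha\, \frac{\partial L}{\partial Y_i}\, B^{\top}, \qquad
\frac{\partial L}{\partial B} = \alpha \sum_i A_i^{\top}\, \frac{\partial L}{\partial Y_i}.
\]

The a-side gradient remains a broadcast matmul (batched dL/dY against the shared B), so functional::MatMul still serves there; only the b side needs the batch reduction.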
2 changes: 0 additions & 2 deletions oneflow/core/autograd/gradient_funcs/transpose.cpp
@@ -60,11 +60,9 @@ Maybe<void> Transpose::Apply(const TransposeCaptureState* ctx, const TensorTuple
                              TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
-  MutableAttrMap attrs;
   std::vector<int32_t> grad_perm;
   grad_perm.resize(ctx->perm.size());
   FOR_RANGE(int32_t, i, 0, ctx->perm.size()) { grad_perm.at(ctx->perm.at(i)) = i; }
-  JUST(attrs.SetAttr<std::vector<int32_t>>("perm", grad_perm));
   in_grads->at(0) = JUST(functional::Transpose(out_grads.at(0), grad_perm));
   return Maybe<void>::Ok();
 }
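The two deleted lines are leftovers from before functional::Transpose accepted the permutation directly: the attrs map was built and filled but never consumed. The grad_perm that remains is the inverse of the forward permutation; the forward transpose reads input axis perm[i] into output axis i, so the backward transpose must route gradient axis i back to input axis perm[i]. A minimal, self-contained C++ sketch of that inversion (illustrative names, not OneFlow code):

#include <cstdint>
#include <cstdio>
#include <vector>

// Invert a permutation: the forward transpose writes input axis perm[i]
// to output axis i, so the backward pass needs inv with inv[perm[i]] = i,
// mirroring the FOR_RANGE loop in the diff above.
std::vector<int32_t> InversePerm(const std::vector<int32_t>& perm) {
  std::vector<int32_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) { inv[perm[i]] = static_cast<int32_t>(i); }
  return inv;
}

int main() {
  const std::vector<int32_t> perm = {2, 0, 1};  // forward permutation
  const std::vector<int32_t> inv = InversePerm(perm);
  for (int32_t p : inv) { std::printf("%d ", p); }  // prints: 1 2 0
  std::printf("\n");
  return 0;
}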
