From 08a0fa1e9b17b1bed739f63f9c2670002e4f6452 Mon Sep 17 00:00:00 2001
From: ZZK <42901638+MARD1NO@users.noreply.github.com>
Date: Tue, 17 Aug 2021 09:00:33 +0800
Subject: [PATCH] fix grad error (#5914)

---
 .../autograd/gradient_funcs/activation.cpp | 37 +++++++++++++------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/oneflow/core/autograd/gradient_funcs/activation.cpp b/oneflow/core/autograd/gradient_funcs/activation.cpp
index 27bd001e784..1cb4643902f 100644
--- a/oneflow/core/autograd/gradient_funcs/activation.cpp
+++ b/oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -93,7 +93,7 @@ class Softsign : public BaseActivation {
   }
 };
 
-class ReLU : public BaseActivation {
+class GeLU : public BaseActivation {
  public:
   Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
@@ -101,13 +101,13 @@ class ReLU : public BaseActivation {
     in_grads->resize(1);
     if (ctx->requires_grad) {
       const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), x));
+      in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x));
     }
     return Maybe<void>::Ok();
   }
 };
 
-class GeLU : public BaseActivation {
+class HardSigmoid : public BaseActivation {
  public:
   Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
@@ -115,13 +115,13 @@ class GeLU : public BaseActivation {
     in_grads->resize(1);
     if (ctx->requires_grad) {
       const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x));
+      in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x));
     }
     return Maybe<void>::Ok();
   }
 };
 
-class HardSigmoid : public BaseActivation {
+class HardSwish : public BaseActivation {
  public:
   Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
@@ -129,27 +129,42 @@ class HardSigmoid : public BaseActivation {
     in_grads->resize(1);
     if (ctx->requires_grad) {
       const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x));
+      in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x));
     }
     return Maybe<void>::Ok();
   }
 };
 
-class HardSwish : public BaseActivation {
+// ===== Activation with parms ====
+struct ReLUInterpState : public OpExprInterpState {
+  bool requires_grad;
+};
+
+class ReLU : public OpExprGradFunction<ReLUInterpState> {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(ReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (ctx->requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const ReLUInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
     if (ctx->requires_grad) {
-      const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x));
+      const auto& y = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));
     }
     return Maybe<void>::Ok();
   }
 };
 
-// ===== Activation with parms ====
 struct LeakyReluInterpState : public OpExprInterpState {
   bool requires_grad;
   float alpha;
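
Note (editor's sketch, not part of the patch): the heart of the fix is the new ReLU::Capture, which saves outputs.at(0), i.e. the forward output y, and the backward then calls functional::ReluGrad(out_grads.at(0), y). Saving the output is sufficient because ReLU's derivative can be read off its output alone: y > 0 exactly where x > 0. A minimal standalone C++ illustration of that idea follows, using a hypothetical ReluGradFromOutput helper rather than the OneFlow API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the shape of functional::ReluGrad(dy, y):
// the ReLU backward needs only the saved forward output y, since the
// activation was "on" exactly where y is positive.
std::vector<float> ReluGradFromOutput(const std::vector<float>& dy,
                                      const std::vector<float>& y) {
  std::vector<float> dx(dy.size());
  for (std::size_t i = 0; i < dy.size(); ++i) {
    dx[i] = (y[i] > 0.0f) ? dy[i] : 0.0f;  // pass the gradient only where ReLU was active
  }
  return dx;
}

int main() {
  const std::vector<float> y = {0.0f, 2.5f, 0.0f, 1.0f};   // forward outputs relu(x)
  const std::vector<float> dy = {0.3f, 0.3f, 0.3f, 0.3f};  // upstream gradient
  for (float v : ReluGradFromOutput(dy, y)) { std::printf("%.1f ", v); }  // prints: 0.0 0.3 0.0 0.3
  std::printf("\n");
  return 0;
}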