fix grad error (#5914)
MARD1NO committed Aug 17, 2021
1 parent c89b3ff commit 08a0fa1
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -93,63 +93,78 @@ class Softsign : public BaseActivation {
   }
 };
 
-class ReLU : public BaseActivation {
+class GeLU : public BaseActivation {
  public:
   Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
     if (ctx->requires_grad) {
       const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), x));
+      in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x));
     }
     return Maybe<void>::Ok();
   }
 };
 
-class GeLU : public BaseActivation {
+class HardSigmoid : public BaseActivation {
  public:
   Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
     if (ctx->requires_grad) {
       const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::GeluGrad(out_grads.at(0), x));
+      in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x));
     }
     return Maybe<void>::Ok();
   }
 };
 
-class HardSigmoid : public BaseActivation {
+class HardSwish : public BaseActivation {
  public:
   Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
     if (ctx->requires_grad) {
       const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::HardSigmoidGrad(out_grads.at(0), x));
+      in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x));
     }
     return Maybe<void>::Ok();
   }
 };
 
-class HardSwish : public BaseActivation {
+// ===== Activation with parms ====
+struct ReLUInterpState : public OpExprInterpState {
+  bool requires_grad;
+};
+
+class ReLU : public OpExprGradFunction<ReLUInterpState> {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(ReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (ctx->requires_grad) { ctx->SaveTensorForBackward(outputs.at(0)); }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const ReLUInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
     if (ctx->requires_grad) {
-      const auto& x = ctx->SavedTensors().at(0);
-      in_grads->at(0) = JUST(functional::HardSwishGrad(out_grads.at(0), x));
+      const auto& y = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::ReluGrad(out_grads.at(0), y));
     }
     return Maybe<void>::Ok();
   }
 };
 
-// ===== Activation with parms ====
 struct LeakyReluInterpState : public OpExprInterpState {
   bool requires_grad;
   float alpha;
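The substantive change in this hunk: ReLU no longer reuses BaseActivation (which captures the forward input); it gets its own ReLUInterpState, its Capture saves the forward output outputs.at(0), and its Apply passes that saved output y to functional::ReluGrad instead of the input x. Mathematically this works because relu(x) is positive exactly where x is positive, so the gradient mask dL/dx = dL/dy * [y > 0] can be rebuilt from y alone. A minimal standalone sketch of that rule, assuming nothing about OneFlow's kernels (ReluGradFromOutput is an illustrative helper, not an OneFlow API):

// Minimal sketch (not OneFlow code): ReLU backward computed from the saved
// output y. Since y = relu(x) is positive exactly where x is positive,
// dx[i] = dy[i] when y[i] > 0 and 0 otherwise -- the input x is not needed.
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> ReluGradFromOutput(const std::vector<float>& dy,
                                      const std::vector<float>& y) {
  std::vector<float> dx(dy.size());
  for (std::size_t i = 0; i < dy.size(); ++i) {
    dx[i] = (y[i] > 0.0f) ? dy[i] : 0.0f;  // mask taken from the output
  }
  return dx;
}

int main() {
  // Forward: x = {-1, 2, -3, 5}  ->  y = relu(x) = {0, 2, 0, 5}
  std::vector<float> y = {0.0f, 2.0f, 0.0f, 5.0f};
  std::vector<float> dy = {1.0f, 1.0f, 1.0f, 1.0f};  // upstream gradient
  std::vector<float> dx = ReluGradFromOutput(dy, y);
  for (float v : dx) { std::printf("%g ", v); }  // prints: 0 1 0 1
  std::printf("\n");
  return 0;
}

A design note, offered as an assumption rather than something stated in the commit message: capturing the output instead of the input also means the forward input does not have to be kept alive solely for ReLU's backward pass.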
