rename OpExprInterpState to AutoGradCaptureState (#5918)
* rename OpExprInterpState to AutoGradCaptureState

* fix

* fix

* delete comment

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
simonJJJ and oneflow-ci-bot committed Aug 19, 2021
1 parent f7d738a commit c2931ec
Showing 69 changed files with 488 additions and 488 deletions.
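
The rename is mechanical: every per-op state struct that previously derived from OpExprInterpState now derives from AutoGradCaptureState, each *InterpState struct becomes the corresponding *CaptureState, and the OpExprGradFunction<...> template argument changes to match. A minimal sketch of the post-rename pattern follows; the op MyOp and functional::MyOpGrad are hypothetical placeholders, and the header path plus the SaveTensorForBackward/SavedTensors helpers are assumed from the surrounding OneFlow code rather than shown in this diff.

#include "oneflow/core/framework/op_expr_grad_function.h"  // assumed header path

namespace oneflow {
namespace one {

struct MyOpCaptureState : public AutoGradCaptureState {
  bool requires_grad;  // recorded in Capture, read back in Apply
};

class MyOp : public OpExprGradFunction<MyOpCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Forward pass: record whatever the backward pass will need.
  Maybe<void> Capture(MyOpCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (ctx->requires_grad) { ctx->SaveTensorForBackward(inputs.at(0)); }
    return Maybe<void>::Ok();
  }

  // Backward pass: fill in_grads from out_grads and the captured state.
  Maybe<void> Apply(const MyOpCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(1);
    if (ctx->requires_grad) {
      const auto& x = ctx->SavedTensors().at(0);
      in_grads->at(0) = JUST(functional::MyOpGrad(out_grads.at(0), x));  // hypothetical functor
    }
    return Maybe<void>::Ok();
  }
};

}  // namespace one
}  // namespace oneflow
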
61 changes: 31 additions & 30 deletions oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -19,15 +19,15 @@ limitations under the License.
namespace oneflow {
namespace one {

-struct BaseActivationInterpState : public OpExprInterpState {
+struct BaseActivationCaptureState : public AutoGradCaptureState {
bool requires_grad;
};

-class BaseActivation : public OpExprGradFunction<BaseActivationInterpState> {
+class BaseActivation : public OpExprGradFunction<BaseActivationCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

-  Maybe<void> Capture(BaseActivationInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(BaseActivationCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -39,7 +39,7 @@ class BaseActivation : public OpExprGradFunction<BaseActivationInterpState> {

class Silu : public BaseActivation {
public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -53,7 +53,7 @@ class Silu : public BaseActivation {

class Mish : public BaseActivation {
public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -67,7 +67,7 @@ class Mish : public BaseActivation {

class Selu : public BaseActivation {
public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -81,7 +81,7 @@ class Selu : public BaseActivation {

class Softsign : public BaseActivation {
public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -95,7 +95,7 @@ class Softsign : public BaseActivation {

class GeLU : public BaseActivation {
public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -109,7 +109,7 @@ class GeLU : public BaseActivation {

class HardSigmoid : public BaseActivation {
public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -123,7 +123,7 @@ class HardSigmoid : public BaseActivation {

class HardSwish : public BaseActivation {
public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -136,15 +136,15 @@ class HardSwish : public BaseActivation {
};

// ===== Activation with parms ====
-struct ReLUInterpState : public OpExprInterpState {
+struct ReLUCaptureState : public AutoGradCaptureState {
bool requires_grad;
};

-class ReLU : public OpExprGradFunction<ReLUInterpState> {
+class ReLU : public OpExprGradFunction<ReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

-  Maybe<void> Capture(ReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(ReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -153,7 +153,7 @@ class ReLU : public OpExprGradFunction<ReLUInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Apply(const ReLUInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -165,12 +165,13 @@ class ReLU : public OpExprGradFunction<ReLUInterpState> {
}
};

-struct LeakyReluInterpState : public OpExprInterpState {
+// ===== Activation with parms ====
+struct LeakyReluCaptureState : public AutoGradCaptureState {
bool requires_grad;
float alpha;
};

-class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
+class LeakyRelu : public OpExprGradFunction<LeakyReluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -179,7 +180,7 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Capture(LeakyReluInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -191,7 +192,7 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Apply(const LeakyReluInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const LeakyReluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -206,13 +207,13 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
AttrMap base_attrs_;
};

-struct HardTanhInterpState : public OpExprInterpState {
+struct HardTanhCaptureState : public AutoGradCaptureState {
bool requires_grad;
double min_val;
double max_val;
};

-class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
+class HardTanh : public OpExprGradFunction<HardTanhCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -221,7 +222,7 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Capture(HardTanhInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(HardTanhCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -234,7 +235,7 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Apply(const HardTanhInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const HardTanhCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -250,12 +251,12 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
AttrMap base_attrs_;
};

-struct EluInterpState : public OpExprInterpState {
+struct EluCaptureState : public AutoGradCaptureState {
bool requires_grad;
double alpha;
};

-class Elu : public OpExprGradFunction<EluInterpState> {
+class Elu : public OpExprGradFunction<EluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -264,7 +265,7 @@ class Elu : public OpExprGradFunction<EluInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Capture(EluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(EluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -276,7 +277,7 @@ class Elu : public OpExprGradFunction<EluInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Apply(const EluInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const EluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -291,16 +292,16 @@ class Elu : public OpExprGradFunction<EluInterpState> {
AttrMap base_attrs_;
};

-struct PReLUInterpState : public OpExprInterpState {
+struct PReLUCaptureState : public AutoGradCaptureState {
bool input_requires_grad;
bool alpha_requires_grad;
};

-class PReLU : public OpExprGradFunction<PReLUInterpState> {
+class PReLU : public OpExprGradFunction<PReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

-  Maybe<void> Capture(PReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(PReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input
@@ -311,7 +312,7 @@ class PReLU : public OpExprGradFunction<PReLUInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Apply(const PReLUInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const PReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
const auto& dy = out_grads.at(0);
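The parameterized activations above (LeakyRelu, HardTanh, Elu) capture op attributes alongside the requires_grad flag; the attribute reads sit in the collapsed hunks. A hedged sketch of LeakyRelu::Capture, assuming it uses OneFlow's ComposedAttrMap helper to merge call-site attrs over the base_attrs_ recorded in Init:

  Maybe<void> Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    ComposedAttrMap composed_attrs(attrs, base_attrs_);  // call-site attrs override op defaults
    ctx->alpha = JUST(composed_attrs.GetAttr<float>("alpha"));
    ctx->SaveTensorForBackward(inputs.at(0));  // backward needs x to pick the slope per element
    return Maybe<void>::Ok();
  }
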
14 changes: 7 additions & 7 deletions oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp
@@ -23,18 +23,18 @@ limitations under the License.
namespace oneflow {
namespace one {

-struct AdaptivePoolInterpState : public OpExprInterpState {
+struct AdaptivePoolCaptureState : public AutoGradCaptureState {
bool requires_grad;
};

-class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolInterpState> {
+class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolCaptureState> {
public:
-  using OpExprGradFunction<AdaptivePoolInterpState>::Init;
+  using OpExprGradFunction<AdaptivePoolCaptureState>::Init;

Maybe<void> Init(const OpExpr& op, std::string mode, const int& ndims);
-  Maybe<void> Capture(AdaptivePoolInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const AdaptivePoolInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
@@ -52,7 +52,7 @@ Maybe<void> AdaptivePoolNdGrad::Init(const OpExpr& op, std::string mode, const i
return Maybe<void>::Ok();
}

-Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -61,7 +61,7 @@ Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolInterpState* ctx, const Tens
return Maybe<void>::Ok();
}

-Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolInterpState* ctx,
+Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
8 changes: 4 additions & 4 deletions oneflow/core/autograd/gradient_funcs/add_n.cpp
@@ -18,16 +18,16 @@ limitations under the License.
namespace oneflow {
namespace one {

-struct AddNInterpState : public OpExprInterpState {
+struct AddNCaptureState : public AutoGradCaptureState {
int32_t input_num;
std::vector<bool> requires_grad;
};

-class AddN : public OpExprGradFunction<AddNInterpState> {
+class AddN : public OpExprGradFunction<AddNCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

-  Maybe<void> Capture(AddNInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(AddNCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
ctx->input_num = inputs.size();
ctx->requires_grad.resize(inputs.size());
@@ -37,7 +37,7 @@ class AddN : public OpExprGradFunction<AddNInterpState> {
return Maybe<void>::Ok();
}

-  Maybe<void> Apply(const AddNInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(ctx->input_num);
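AddN illustrates the multi-input case: Capture records one requires_grad flag per input, and because the gradient of a sum with respect to each addend is the identity, the collapsed remainder of Apply presumably forwards the single output gradient to every input that needs one. A sketch under that assumption:

  Maybe<void> Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(ctx->input_num);
    for (int i = 0; i < ctx->input_num; ++i) {
      // d(x_0 + ... + x_{n-1}) / d(x_i) == 1, so each needed input just receives out_grad.
      if (ctx->requires_grad.at(i)) { in_grads->at(i) = out_grads.at(0); }
    }
    return Maybe<void>::Ok();
  }
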
12 changes: 6 additions & 6 deletions oneflow/core/autograd/gradient_funcs/avg_pooling.cpp
@@ -26,7 +26,7 @@ namespace one {

namespace {

-struct AvgPoolingInterpState : public OpExprInterpState {
+struct AvgPoolingCaptureState : public AutoGradCaptureState {
bool requires_grad;
size_t input_index;
size_t output_index;
@@ -40,13 +40,13 @@ struct AvgPoolingInterpState : public OpExprInterpState {
int64_t divisor_override;
};

-class AvgPoolingNdGrad : public OpExprGradFunction<AvgPoolingInterpState> {
+class AvgPoolingNdGrad : public OpExprGradFunction<AvgPoolingCaptureState> {
public:
virtual ~AvgPoolingNdGrad() = default;
Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(AvgPoolingInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(AvgPoolingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const AvgPoolingInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AvgPoolingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
@@ -60,7 +60,7 @@ Maybe<void> AvgPoolingNdGrad::Init(const OpExpr& op) {
return Maybe<void>::Ok();
}

-Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -80,7 +80,7 @@ Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingInterpState* ctx, const TensorTu
return Maybe<void>::Ok();
}

-Maybe<void> AvgPoolingNdGrad::Apply(const AvgPoolingInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> AvgPoolingNdGrad::Apply(const AvgPoolingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
12 changes: 6 additions & 6 deletions oneflow/core/autograd/gradient_funcs/batch_gather.cpp
@@ -22,17 +22,17 @@ limitations under the License.
namespace oneflow {
namespace one {

-struct BatchGatherInterpState : public OpExprInterpState {
+struct BatchGatherCaptureState : public AutoGradCaptureState {
int64_t num_segments;
bool requires_grad;
};

-class BatchGather : public OpExprGradFunction<BatchGatherInterpState> {
+class BatchGather : public OpExprGradFunction<BatchGatherCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
@@ -48,7 +48,7 @@ Maybe<void> BatchGather::Init(const OpExpr& op) {
return Maybe<void>::Ok();
}

-Maybe<void> BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> BatchGather::Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -59,7 +59,7 @@ Maybe<void> BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple&
return Maybe<void>::Ok();
}

-Maybe<void> BatchGather::Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
in_grads->resize(2);
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
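Each gradient function is tied to its forward op type name in the collapsed tail of every file. Assuming the standard OneFlow registration macro, the pattern looks like:

  REGISTER_OP_EXPR_GRAD_FUNCTION("batch_gather", BatchGather);
  REGISTER_OP_EXPR_GRAD_FUNCTION("relu", ReLU);  // and likewise for each class above
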
(The remaining 64 changed files are not shown; they apply the same rename.)
