diff --git a/oneflow/core/autograd/gradient_funcs/activation.cpp b/oneflow/core/autograd/gradient_funcs/activation.cpp
index 1cb4643902f..42d8ba4bf2a 100644
--- a/oneflow/core/autograd/gradient_funcs/activation.cpp
+++ b/oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -19,15 +19,15 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct BaseActivationInterpState : public OpExprInterpState {
+struct BaseActivationCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class BaseActivation : public OpExprGradFunction<BaseActivationInterpState> {
+class BaseActivation : public OpExprGradFunction<BaseActivationCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(BaseActivationInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(BaseActivationCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -39,7 +39,7 @@ class BaseActivation : public OpExprGradFunction<BaseActivationInterpState> {
 
 class Silu : public BaseActivation {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -53,7 +53,7 @@ class Silu : public BaseActivation {
 
 class Mish : public BaseActivation {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -67,7 +67,7 @@ class Mish : public BaseActivation {
 
 class Selu : public BaseActivation {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -81,7 +81,7 @@ class Selu : public BaseActivation {
 
 class Softsign : public BaseActivation {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -95,7 +95,7 @@ class Softsign : public BaseActivation {
 
 class GeLU : public BaseActivation {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -109,7 +109,7 @@ class GeLU : public BaseActivation {
 
 class HardSigmoid : public BaseActivation {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -123,7 +123,7 @@ class HardSigmoid : public BaseActivation {
 
 class HardSwish : public BaseActivation {
  public:
-  Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -136,15 +136,15 @@ class HardSwish : public BaseActivation {
 };
 
 // ===== Activation with parms ====
-struct ReLUInterpState : public OpExprInterpState {
+struct ReLUCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class ReLU : public OpExprGradFunction<ReLUInterpState> {
+class ReLU : public OpExprGradFunction<ReLUCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(ReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(ReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -153,7 +153,7 @@ class ReLU : public OpExprGradFunction<ReLUInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const ReLUInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ReLUCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -165,12 +165,13 @@ class ReLU : public OpExprGradFunction<ReLUInterpState> {
   }
 };
 
-struct LeakyReluInterpState : public OpExprInterpState {
+// ===== Activation with parms ====
+struct LeakyReluCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   float alpha;
 };
 
-class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
+class LeakyRelu : public OpExprGradFunction<LeakyReluCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -179,7 +180,7 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(LeakyReluInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -191,7 +192,7 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const LeakyReluInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const LeakyReluCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -206,13 +207,13 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
   AttrMap base_attrs_;
 };
 
-struct HardTanhInterpState : public OpExprInterpState {
+struct HardTanhCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   double min_val;
   double max_val;
 };
 
-class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
+class HardTanh : public OpExprGradFunction<HardTanhCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -221,7 +222,7 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(HardTanhInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(HardTanhCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
     ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -234,7 +235,7 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const HardTanhInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const HardTanhCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -250,12 +251,12 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
   AttrMap base_attrs_;
 };
 
-struct EluInterpState : public OpExprInterpState {
+struct EluCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   double alpha;
 };
 
-class Elu : public OpExprGradFunction<EluInterpState> {
+class Elu : public OpExprGradFunction<EluCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -264,7 +265,7 @@ class Elu : public OpExprGradFunction<EluInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(EluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(EluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -276,7 +277,7 @@ class Elu : public OpExprGradFunction<EluInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const EluInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const EluCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -291,16 +292,16 @@ class Elu : public OpExprGradFunction<EluInterpState> {
   AttrMap base_attrs_;
 };
 
-struct PReLUInterpState : public OpExprInterpState {
+struct PReLUCaptureState : public AutoGradCaptureState {
   bool input_requires_grad;
   bool alpha_requires_grad;
 };
 
-class PReLU : public OpExprGradFunction<PReLUInterpState> {
+class PReLU : public OpExprGradFunction<PReLUCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(PReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(PReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 2);
     ctx->input_requires_grad = inputs.at(0)->requires_grad();  // input
@@ -311,7 +312,7 @@ class PReLU : public OpExprGradFunction<PReLUInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const PReLUInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const PReLUCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     const auto& dy = out_grads.at(0);
diff --git a/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp b/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp
index cc4009a8098..2bef644efd2 100644
--- a/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp
+++ b/oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp
@@ -23,18 +23,18 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct AdaptivePoolInterpState : public OpExprInterpState {
+struct AdaptivePoolCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolInterpState> {
+class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolCaptureState> {
  public:
-  using OpExprGradFunction<AdaptivePoolInterpState>::Init;
+  using OpExprGradFunction<AdaptivePoolCaptureState>::Init;
 
   Maybe<void> Init(const OpExpr& op, std::string mode, const int& ndims);
-  Maybe<void> Capture(AdaptivePoolInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const AdaptivePoolInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -52,7 +52,7 @@ Maybe<void> AdaptivePoolNdGrad::Init(const OpExpr& op, std::string mode, const i
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
                                         const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -61,7 +61,7 @@ Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolInterpState* ctx, const Tens
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolInterpState* ctx,
+Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolCaptureState* ctx,
                                       const TensorTuple& out_grads, TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/add_n.cpp b/oneflow/core/autograd/gradient_funcs/add_n.cpp
index 5c083ffd587..2748de7a063 100644
--- a/oneflow/core/autograd/gradient_funcs/add_n.cpp
+++ b/oneflow/core/autograd/gradient_funcs/add_n.cpp
@@ -18,16 +18,16 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct AddNInterpState : public OpExprInterpState {
+struct AddNCaptureState : public AutoGradCaptureState {
   int32_t input_num;
   std::vector<bool> requires_grad;
 };
 
-class AddN : public OpExprGradFunction<AddNInterpState> {
+class AddN : public OpExprGradFunction<AddNCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(AddNInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(AddNCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override {
     ctx->input_num = inputs.size();
     ctx->requires_grad.resize(inputs.size());
@@ -37,7 +37,7 @@ class AddN : public OpExprGradFunction<AddNInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const AddNInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(ctx->input_num);
diff --git a/oneflow/core/autograd/gradient_funcs/avg_pooling.cpp b/oneflow/core/autograd/gradient_funcs/avg_pooling.cpp
index f0bd56a2d8e..526b4458e65 100644
--- a/oneflow/core/autograd/gradient_funcs/avg_pooling.cpp
+++ b/oneflow/core/autograd/gradient_funcs/avg_pooling.cpp
@@ -26,7 +26,7 @@ namespace one {
 
 namespace {
 
-struct AvgPoolingInterpState : public OpExprInterpState {
+struct AvgPoolingCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   size_t input_index;
   size_t output_index;
@@ -40,13 +40,13 @@ struct AvgPoolingInterpState : public OpExprInterpState {
   int64_t divisor_override;
 };
 
-class AvgPoolingNdGrad : public OpExprGradFunction<AvgPoolingInterpState> {
+class AvgPoolingNdGrad : public OpExprGradFunction<AvgPoolingCaptureState> {
  public:
   virtual ~AvgPoolingNdGrad() = default;
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(AvgPoolingInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(AvgPoolingCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const AvgPoolingInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AvgPoolingCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -60,7 +60,7 @@ Maybe<void> AvgPoolingNdGrad::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingCaptureState* ctx, const TensorTuple& inputs,
                                       const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -80,7 +80,7 @@ Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingInterpState* ctx, const TensorTu
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AvgPoolingNdGrad::Apply(const AvgPoolingInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> AvgPoolingNdGrad::Apply(const AvgPoolingCaptureState* ctx, const TensorTuple& out_grads,
                                     TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/batch_gather.cpp b/oneflow/core/autograd/gradient_funcs/batch_gather.cpp
index bfba2bd0c6c..b21fc693481 100644
--- a/oneflow/core/autograd/gradient_funcs/batch_gather.cpp
+++ b/oneflow/core/autograd/gradient_funcs/batch_gather.cpp
@@ -22,17 +22,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct BatchGatherInterpState : public OpExprInterpState {
+struct BatchGatherCaptureState : public AutoGradCaptureState {
   int64_t num_segments;
   bool requires_grad;
 };
 
-class BatchGather : public OpExprGradFunction<BatchGatherInterpState> {
+class BatchGather : public OpExprGradFunction<BatchGatherCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -48,7 +48,7 @@ Maybe<void> BatchGather::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> BatchGather::Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
                                  const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -59,7 +59,7 @@ Maybe<void> BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple&
   return Maybe<void>::Ok();
 }
 
-Maybe<void> BatchGather::Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
                                TensorTuple* in_grads) const {
   in_grads->resize(2);
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
diff --git a/oneflow/core/autograd/gradient_funcs/bias_add.cpp b/oneflow/core/autograd/gradient_funcs/bias_add.cpp
index 04378ee553e..6c2a52c1959 100644
--- a/oneflow/core/autograd/gradient_funcs/bias_add.cpp
+++ b/oneflow/core/autograd/gradient_funcs/bias_add.cpp
@@ -23,13 +23,13 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct BiasAddInterpState : public OpExprInterpState {
+struct BiasAddCaptureState : public AutoGradCaptureState {
   bool input_requires_grad;
   bool bias_requires_grad;
   int32_t axis;
 };
 
-class BiasAdd : public OpExprGradFunction<BiasAddInterpState> {
+class BiasAdd : public OpExprGradFunction<BiasAddCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -42,7 +42,7 @@ class BiasAdd : public OpExprGradFunction<BiasAddInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(BiasAddInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(BiasAddCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 2);
     ctx->input_requires_grad = inputs.at(0)->requires_grad();
@@ -52,7 +52,7 @@ class BiasAdd : public OpExprGradFunction<BiasAddInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const BiasAddInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BiasAddCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const int64_t num_axes = out_grads.at(0)->shape()->NumAxes();
     in_grads->resize(2);
diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
index b5c0ef859f9..d3d7eb28588 100644
--- a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
+++ b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
@@ -21,15 +21,15 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-class BroadcastBinaryGrad : public OpExprGradFunction<OpExprInterpState> {
+class BroadcastBinaryGrad : public OpExprGradFunction<AutoGradCaptureState> {
  public:
   BroadcastBinaryGrad() = default;
   virtual ~BroadcastBinaryGrad() = default;
 
   virtual Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
-                      const AttrMap& attrs) const override {
+  Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 2);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
     ctx->SaveTensorForBackward(inputs.at(0));
@@ -41,7 +41,7 @@ class BroadcastBinaryGrad : public OpExprGradFunction<OpExprInterpState> {
 
 class BroadcastAdd : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
@@ -60,7 +60,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_add", BroadcastAdd);
 
 class BroadcastSub : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
@@ -80,7 +80,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_sub", BroadcastSub);
 
 class BroadcastMul : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
@@ -101,7 +101,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_mul", BroadcastMul);
 
 class BroadcastDiv : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
@@ -122,7 +122,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div", BroadcastDiv);
 
 class BroadcastMinMax : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp
index 47c9217b6b6..b1a58cc3403 100644
--- a/oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp
+++ b/oneflow/core/autograd/gradient_funcs/broadcast_floor_mod.cpp
@@ -19,22 +19,22 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct BroadcastFModInterpState : public OpExprInterpState {
+struct BroadcastFModCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class BroadcastFMod : public OpExprGradFunction<BroadcastFModInterpState> {
+class BroadcastFMod : public OpExprGradFunction<BroadcastFModCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(BroadcastFModInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(BroadcastFModCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 2);
     ctx->requires_grad = inputs.at(0)->requires_grad();
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const BroadcastFModInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BroadcastFModCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(2);
diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_like.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_like.cpp
index b1571e5413f..7968fae8ad7 100644
--- a/oneflow/core/autograd/gradient_funcs/broadcast_like.cpp
+++ b/oneflow/core/autograd/gradient_funcs/broadcast_like.cpp
@@ -22,17 +22,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct BroadCastLikeInterpState : public OpExprInterpState {
+struct BroadCastLikeCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   std::vector<int32_t> broadcast_axes;
 };
 
-class BroadCastLike : public OpExprGradFunction<BroadCastLikeInterpState> {
+class BroadCastLike : public OpExprGradFunction<BroadCastLikeCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(BroadCastLikeInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const BroadCastLikeInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -50,7 +50,7 @@ Maybe<void> BroadCastLike::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> BroadCastLike::Capture(BroadCastLikeInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> BroadCastLike::Capture(BroadCastLikeCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -61,7 +61,7 @@ Maybe<void> BroadCastLike::Capture(BroadCastLikeInterpState* ctx, const TensorTu
   return Maybe<void>::Ok();
 }
 
-Maybe<void> BroadCastLike::Apply(const BroadCastLikeInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> BroadCastLike::Apply(const BroadCastLikeCaptureState* ctx, const TensorTuple& out_grads,
                                  TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/cast.cpp b/oneflow/core/autograd/gradient_funcs/cast.cpp
index b4f60b19028..7a8934b0133 100644
--- a/oneflow/core/autograd/gradient_funcs/cast.cpp
+++ b/oneflow/core/autograd/gradient_funcs/cast.cpp
@@ -24,11 +24,11 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct CastOpExprInterpState : public OpExprInterpState {
+struct CastCaptureState : public AutoGradCaptureState {
   DataType data_type;
 };
 
-class Cast : public OpExprGradFunction<CastOpExprInterpState> {
+class Cast : public OpExprGradFunction<CastCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -38,13 +38,13 @@ class Cast : public OpExprGradFunction<CastOpExprInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(CastOpExprInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+  Maybe<void> Capture(CastCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
     ctx->data_type = inputs.at(0)->dtype()->data_type();
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const CastOpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const CastCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     in_grads->resize(1);
     MutableAttrMap attrs;
diff --git a/oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp b/oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp
index 7fa5963c770..f87dbac7ddd 100644
--- a/oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp
+++ b/oneflow/core/autograd/gradient_funcs/clip_by_scalar.cpp
@@ -19,13 +19,13 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct ClipByScalarInterpState : public OpExprInterpState {
+struct ClipByScalarCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   functional::Scalar min;
   functional::Scalar max;
 };
 
-class ClipByScalar : public OpExprGradFunction<ClipByScalarInterpState> {
+class ClipByScalar : public OpExprGradFunction<ClipByScalarCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -34,7 +34,7 @@ class ClipByScalar : public OpExprGradFunction<ClipByScalarInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(ClipByScalarInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(ClipByScalarCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -54,7 +54,7 @@ class ClipByScalar : public OpExprGradFunction<ClipByScalarInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const ClipByScalarInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ClipByScalarCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
diff --git a/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
index 587740aef97..291d56d81fd 100644
--- a/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
+++ b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
@@ -19,12 +19,12 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct ClipByScalarMaxInterpState : public OpExprInterpState {
+struct ClipByScalarMaxCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   functional::Scalar max;
 };
 
-class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxInterpState> {
+class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -33,7 +33,7 @@ class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(ClipByScalarMaxInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(ClipByScalarMaxCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -51,7 +51,7 @@ class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const ClipByScalarMaxInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ClipByScalarMaxCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
diff --git a/oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp
index f73470dfccd..35ebc620b51 100644
--- a/oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp
+++ b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_min.cpp
@@ -19,12 +19,12 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct ClipByScalarMinInterpState : public OpExprInterpState {
+struct ClipByScalarMinCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   functional::Scalar min;
 };
 
-class ClipByScalarMin : public OpExprGradFunction<ClipByScalarMinInterpState> {
+class ClipByScalarMin : public OpExprGradFunction<ClipByScalarMinCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -33,7 +33,7 @@ class ClipByScalarMin : public OpExprGradFunction<ClipByScalarMinInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(ClipByScalarMinInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(ClipByScalarMinCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -51,7 +51,7 @@ class ClipByScalarMin : public OpExprGradFunction<ClipByScalarMinInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const ClipByScalarMinInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ClipByScalarMinCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
diff --git a/oneflow/core/autograd/gradient_funcs/combined_margin_loss.cpp b/oneflow/core/autograd/gradient_funcs/combined_margin_loss.cpp
index 3f22cc5b7ca..9aea3f43512 100644
--- a/oneflow/core/autograd/gradient_funcs/combined_margin_loss.cpp
+++ b/oneflow/core/autograd/gradient_funcs/combined_margin_loss.cpp
@@ -21,7 +21,7 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct CombinedMarginLossInterpState : public OpExprInterpState {
+struct CombinedMarginLossCaptureState : public AutoGradCaptureState {
   float m1;
   float m2;
   float m3;
@@ -31,7 +31,7 @@ struct CombinedMarginLossInterpState : public OpExprInterpState {
   bool requires_grad;
 };
 
-class CombinedMarginLoss : public OpExprGradFunction<CombinedMarginLossInterpState> {
+class CombinedMarginLoss : public OpExprGradFunction<CombinedMarginLossCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -40,7 +40,7 @@ class CombinedMarginLoss : public OpExprGradFunction<CombinedMarginLossInterpSt
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(CombinedMarginLossInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(CombinedMarginLossCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 2);
     ctx->requires_grad = inputs.at(0)->requires_grad();  // x
@@ -57,7 +57,7 @@ class CombinedMarginLoss : public OpExprGradFunction<CombinedMarginLossInterpSt
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const CombinedMarginLossInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const CombinedMarginLossCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 2);
     in_grads->resize(2);
diff --git a/oneflow/core/autograd/gradient_funcs/concat.cpp b/oneflow/core/autograd/gradient_funcs/concat.cpp
index 9bbb13aa50d..3d71b118f6f 100644
--- a/oneflow/core/autograd/gradient_funcs/concat.cpp
+++ b/oneflow/core/autograd/gradient_funcs/concat.cpp
@@ -23,18 +23,18 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct ConcatInterpState : public OpExprInterpState {
+struct ConcatCaptureState : public AutoGradCaptureState {
   std::vector<bool> requires_grad;
   int64_t axis;
   int64_t input_num;
 };
 
-class Concat : public OpExprGradFunction<ConcatInterpState> {
+class Concat : public OpExprGradFunction<ConcatCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(ConcatInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
-                      const AttrMap& attrs) const override;
-  Maybe<void> Apply(const ConcatInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Capture(ConcatCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -55,7 +55,7 @@ Maybe<void> Concat::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Concat::Capture(ConcatInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Concat::Capture(ConcatCaptureState* ctx, const TensorTuple& inputs,
                             const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad.resize(inputs.size());
   for (int i = 0; i < inputs.size(); ++i) { ctx->requires_grad[i] = inputs.at(i)->requires_grad(); }
@@ -67,7 +67,7 @@ Maybe<void> Concat::Capture(ConcatInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Concat::Apply(const ConcatInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Concat::Apply(const ConcatCaptureState* ctx, const TensorTuple& out_grads,
                           TensorTuple* in_grads) const {
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
   in_grads->resize(ctx->input_num);
diff --git a/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp b/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp
index 284b8f824a8..87f7eacdd5b 100644
--- a/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp
+++ b/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp
@@ -22,13 +22,13 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct CastConsistentOpExprInterpState : public OpExprInterpState {
+struct CastConsistentCaptureState : public AutoGradCaptureState {
   Symbol<ParallelDesc> parallel_desc;
   Symbol<cfg::NdSbp> nd_sbp;
   std::shared_ptr<const Shape> shape;
 };
 
-class CastToConsistent : public OpExprGradFunction<CastConsistentOpExprInterpState> {
+class CastToConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const CastToConsistentOpExpr*>(&op);
@@ -38,7 +38,7 @@ class CastToConsistent : public OpExprGradFunction<CastConsistentOpExprInterpSt
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(CastConsistentOpExprInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs,
                       const OpExprInterpContext& interp_ctx) const override {
     ctx->parallel_desc = JUST(interp_ctx.parallel_desc.value());
@@ -46,7 +46,7 @@ class CastToConsistent : public OpExprGradFunction<CastConsistentOpExprInterpSt
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const CastConsistentOpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& out_grad = out_grads.at(0);
     CHECK_OR_RETURN(out_grad->is_consistent());
@@ -63,7 +63,7 @@ class CastToConsistent : public OpExprGradFunction<CastConsistentOpExprInterpSt
 };
 
-class CastFromConsistent : public OpExprGradFunction<CastConsistentOpExprInterpState> {
+class CastFromConsistent : public OpExprGradFunction<CastConsistentCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const CastFromConsistentOpExpr*>(&op);
@@ -73,7 +73,7 @@ class CastFromConsistent : public OpExprGradFunction<CastConsistentOpExprInterp
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(CastConsistentOpExprInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(CastConsistentCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     const auto& input = inputs.at(0);
     CHECK_OR_RETURN(input->is_consistent());
@@ -83,7 +83,7 @@ class CastFromConsistent : public OpExprGradFunction<CastConsistentOpExprInterp
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const CastConsistentOpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const CastConsistentCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& dual_nd_sbp = JUST(GetDualNdSbp(ctx->nd_sbp));
     MutableAttrMap attrs;
diff --git a/oneflow/core/autograd/gradient_funcs/conv.cpp b/oneflow/core/autograd/gradient_funcs/conv.cpp
index 8080a931618..cbe84ba5f15 100644
--- a/oneflow/core/autograd/gradient_funcs/conv.cpp
+++ b/oneflow/core/autograd/gradient_funcs/conv.cpp
@@ -23,7 +23,7 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct ConvolutionNdInterpState : public OpExprInterpState {
+struct ConvolutionNdCaptureState : public AutoGradCaptureState {
   bool input_requires_grad = false;
   bool weight_requires_grad = false;
   size_t input_index;
@@ -37,12 +37,12 @@ struct ConvolutionNdInterpState : public OpExprInterpState {
   int32_t groups;
 };
 
-class ConvolutionNd : public OpExprGradFunction<ConvolutionNdInterpState> {
+class ConvolutionNd : public OpExprGradFunction<ConvolutionNdCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -56,7 +56,7 @@ Maybe<void> ConvolutionNd::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> ConvolutionNd::Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> ConvolutionNd::Capture(ConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs, const AttrMap& attrs) const {
   CHECK_EQ_OR_RETURN(inputs.size(), 2);
   ctx->input_requires_grad = inputs.at(0)->requires_grad();
@@ -77,7 +77,7 @@ Maybe<void> ConvolutionNd::Capture(ConvolutionNdInterpState* ctx, const TensorTu
   return Maybe<void>::Ok();
 }
 
-Maybe<void> ConvolutionNd::Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> ConvolutionNd::Apply(const ConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
                                  TensorTuple* in_grads) const {
   in_grads->resize(2);
   size_t num_spatial_dims = ctx->kernel_size.size();
diff --git a/oneflow/core/autograd/gradient_funcs/copy.cpp b/oneflow/core/autograd/gradient_funcs/copy.cpp
index d84665db3ae..4798ec05e52 100644
--- a/oneflow/core/autograd/gradient_funcs/copy.cpp
+++ b/oneflow/core/autograd/gradient_funcs/copy.cpp
@@ -23,12 +23,12 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct CopyOpExprInterpState : public OpExprInterpState {
+struct CopyCaptureState : public AutoGradCaptureState {
   std::string device_type;
   int64_t device_id;
 };
 
-class Copy : public OpExprGradFunction<CopyOpExprInterpState> {
+class Copy : public OpExprGradFunction<CopyCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -38,14 +38,14 @@ class Copy : public OpExprGradFunction<CopyOpExprInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(CopyOpExprInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+  Maybe<void> Capture(CopyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override {
     ctx->device_type = JUST(inputs.at(0)->device())->type();
     ctx->device_id = JUST(inputs.at(0)->device())->device_id();
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const CopyOpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const CopyCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     in_grads->resize(1);
     MutableAttrMap attrs;
diff --git a/oneflow/core/autograd/gradient_funcs/ctc_loss.cpp b/oneflow/core/autograd/gradient_funcs/ctc_loss.cpp
index 252a74946e0..76166b85dda 100644
--- a/oneflow/core/autograd/gradient_funcs/ctc_loss.cpp
+++ b/oneflow/core/autograd/gradient_funcs/ctc_loss.cpp
@@ -23,18 +23,18 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct CTCLossInterpState : public OpExprInterpState {
+struct CTCLossCaptureState : public AutoGradCaptureState {
   int32_t blank;
   bool zero_infinity;
   bool requires_grad;
 };
 
-class CTCLoss : public OpExprGradFunction<CTCLossInterpState> {
+class CTCLoss : public OpExprGradFunction<CTCLossCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(CTCLossInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const CTCLossInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -51,7 +51,7 @@ Maybe<void> CTCLoss::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> CTCLoss::Capture(CTCLossInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> CTCLoss::Capture(CTCLossCaptureState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -71,7 +71,7 @@ Maybe<void> CTCLoss::Capture(CTCLossInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> CTCLoss::Apply(const CTCLossInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> CTCLoss::Apply(const CTCLossCaptureState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 2);
diff --git a/oneflow/core/autograd/gradient_funcs/deconv.cpp b/oneflow/core/autograd/gradient_funcs/deconv.cpp
index fa2e92551f4..c79ad41db51 100644
--- a/oneflow/core/autograd/gradient_funcs/deconv.cpp
+++ b/oneflow/core/autograd/gradient_funcs/deconv.cpp
@@ -23,17 +23,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct DeConvolutionNdInterpState : public OpExprInterpState {
+struct DeConvolutionNdCaptureState : public AutoGradCaptureState {
   bool weight_requires_grad = false;
   bool activation_requires_grad = false;
 };
 
-class DeConvolutionNd : public OpExprGradFunction<DeConvolutionNdInterpState> {
+class DeConvolutionNd : public OpExprGradFunction<DeConvolutionNdCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(DeConvolutionNdInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const DeConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const DeConvolutionNdCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -71,7 +71,7 @@ Maybe<void> DeConvolutionNd::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> DeConvolutionNd::Capture(DeConvolutionNdInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> DeConvolutionNd::Capture(DeConvolutionNdCaptureState* ctx, const TensorTuple& inputs,
                                      const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->activation_requires_grad = inputs.at(0)->requires_grad();
   ctx->weight_requires_grad = inputs.at(1)->requires_grad();
@@ -84,7 +84,7 @@ Maybe<void> DeConvolutionNd::Capture(DeConvolutionNdInterpState* ctx, const Tens
   return Maybe<void>::Ok();
 }
 
-Maybe<void> DeConvolutionNd::Apply(const DeConvolutionNdInterpState* ctx,
+Maybe<void> DeConvolutionNd::Apply(const DeConvolutionNdCaptureState* ctx,
                                    const TensorTuple& out_grads, TensorTuple* in_grads) const {
   in_grads->resize(2);
   if (ctx->activation_requires_grad) {
diff --git a/oneflow/core/autograd/gradient_funcs/diag.cpp b/oneflow/core/autograd/gradient_funcs/diag.cpp
index cfd0aee9daf..ccee46a3c5b 100644
--- a/oneflow/core/autograd/gradient_funcs/diag.cpp
+++ b/oneflow/core/autograd/gradient_funcs/diag.cpp
@@ -20,17 +20,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct DiagInterpState : public OpExprInterpState {
+struct DiagCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   int32_t diagonal;
 };
 
-class Diag : public OpExprGradFunction<DiagInterpState> {
+class Diag : public OpExprGradFunction<DiagCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(DiagInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(DiagCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override;
-  Maybe<void> Apply(const DiagInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -44,7 +44,7 @@ Maybe<void> Diag::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Diag::Capture(DiagInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Diag::Capture(DiagCaptureState* ctx, const TensorTuple& inputs,
                           const TensorTuple& outputs, const AttrMap& attrs) const {
   CHECK_EQ_OR_RETURN(outputs.size(), 1);
   ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -55,7 +55,7 @@ Maybe<void> Diag::Capture(DiagInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Diag::Apply(const DiagInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Diag::Apply(const DiagCaptureState* ctx, const TensorTuple& out_grads,
                         TensorTuple* in_grads) const {
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
   in_grads->resize(2);
diff --git a/oneflow/core/autograd/gradient_funcs/dim_gather.cpp b/oneflow/core/autograd/gradient_funcs/dim_gather.cpp
index 4ae5b63e960..b6a39930b41 100644
--- a/oneflow/core/autograd/gradient_funcs/dim_gather.cpp
+++ b/oneflow/core/autograd/gradient_funcs/dim_gather.cpp
@@ -22,17 +22,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct DimGatherInterpState : public OpExprInterpState {
+struct DimGatherCaptureState : public AutoGradCaptureState {
   int32_t dim;
   bool requires_grad;
 };
 
-class DimGather : public OpExprGradFunction<DimGatherInterpState> {
+class DimGather : public OpExprGradFunction<DimGatherCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(DimGatherInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const DimGatherInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -49,7 +49,7 @@ Maybe<void> DimGather::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> DimGather::Capture(DimGatherInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> DimGather::Capture(DimGatherCaptureState* ctx, const TensorTuple& inputs,
                                const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -62,7 +62,7 @@ Maybe<void> DimGather::Capture(DimGatherInterpState* ctx, const TensorTuple& inp
   return Maybe<void>::Ok();
 }
 
-Maybe<void> DimGather::Apply(const DimGatherInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> DimGather::Apply(const DimGatherCaptureState* ctx, const TensorTuple& out_grads,
                              TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/dim_scatter.cpp b/oneflow/core/autograd/gradient_funcs/dim_scatter.cpp
index 6bda00e3abc..f443de1c3d1 100644
--- a/oneflow/core/autograd/gradient_funcs/dim_scatter.cpp
+++ b/oneflow/core/autograd/gradient_funcs/dim_scatter.cpp
@@ -23,7 +23,7 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct DimScatterInterpState : public OpExprInterpState {
+struct DimScatterCaptureState : public AutoGradCaptureState {
   int32_t dim;
   bool input_requires_grad;
   bool src_requires_grad;
@@ -32,14 +32,14 @@ struct DimScatterInterpState : public OpExprInterpState {
 enum SCATTER_TYPE { SCATTER_UPDATE, SCATTER_ADD };
 
 template<SCATTER_TYPE T>
-class DimScatter : public OpExprGradFunction<DimScatterInterpState> {
+class DimScatter : public OpExprGradFunction<DimScatterCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
-  Maybe<void> ApplyCommon(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> ApplyCommon(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
                           TensorTuple* in_grads) const;
 
  private:
@@ -55,7 +55,7 @@ Maybe<void> DimScatter<T>::Init(const OpExpr& op) {
 }
 
 template<SCATTER_TYPE T>
-Maybe<void> DimScatter<T>::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> DimScatter<T>::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
                                    const TensorTuple& outputs, const AttrMap& attrs) const {
   CHECK_EQ_OR_RETURN(inputs.size(), 3);
   CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -72,7 +72,7 @@ Maybe<void> DimScatter<T>::Capture(DimScatterInterpState* ctx, const TensorTuple
 }
 
 template<SCATTER_TYPE T>
-Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterInterpState* ctx,
+Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterCaptureState* ctx,
                                        const TensorTuple& out_grads, TensorTuple* in_grads) const {
   const std::shared_ptr<Tensor>& index = ctx->SavedTensors().at(0);
@@ -85,7 +85,7 @@ Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterInterpState* ctx,
 }
 
 template<>
-Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterInterpState* ctx,
+Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterCaptureState* ctx,
                                                             const TensorTuple& out_grads,
                                                             TensorTuple* in_grads) const {
   if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
@@ -101,7 +101,7 @@ Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterInte
 }
 
 template<>
-Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterInterpState* ctx,
+Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterCaptureState* ctx,
                                                          const TensorTuple& out_grads,
                                                          TensorTuple* in_grads) const {
   if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
@@ -114,12 +114,12 @@ Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterInterpS
   return Maybe<void>::Ok();
 }
 
-class DimScatterUpdateScalar : public OpExprGradFunction<DimScatterInterpState> {
+class DimScatterUpdateScalar : public OpExprGradFunction<DimScatterCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const DimScatterCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -134,7 +134,7 @@ Maybe<void> DimScatterUpdateScalar::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> DimScatterUpdateScalar::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> DimScatterUpdateScalar::Capture(DimScatterCaptureState* ctx, const TensorTuple& inputs,
                                             const TensorTuple& outputs, const AttrMap& attrs) const {
   CHECK_EQ_OR_RETURN(inputs.size(), 2);
@@ -150,7 +150,7 @@ Maybe<void> DimScatterUpdateScalar::Capture(DimScatterInterpState* ctx, const Te
   return Maybe<void>::Ok();
 }
 
-Maybe<void> DimScatterUpdateScalar::Apply(const DimScatterInterpState* ctx,
+Maybe<void> DimScatterUpdateScalar::Apply(const DimScatterCaptureState* ctx,
                                           const TensorTuple& out_grads,
                                           TensorTuple* in_grads) const {
   if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
diff --git a/oneflow/core/autograd/gradient_funcs/dropout.cpp b/oneflow/core/autograd/gradient_funcs/dropout.cpp
index c1f6e76914e..d79cf0d8aef 100644
--- a/oneflow/core/autograd/gradient_funcs/dropout.cpp
+++ b/oneflow/core/autograd/gradient_funcs/dropout.cpp
@@ -22,17 +22,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct DropoutInterpState : public OpExprInterpState {
+struct DropoutCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   float scale;
 };
 
-class Dropout : public OpExprGradFunction<DropoutInterpState> {
+class Dropout : public OpExprGradFunction<DropoutCaptureState> {
 public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(DropoutInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(DropoutCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const DropoutInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -49,7 +49,7 @@ Maybe<void> Dropout::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Dropout::Capture(DropoutInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Dropout::Capture(DropoutCaptureState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -62,7 +62,7 @@ Maybe<void> Dropout::Capture(DropoutInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Dropout::Apply(const DropoutInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Dropout::Apply(const DropoutCaptureState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/eager_nccl_broadcast.cpp b/oneflow/core/autograd/gradient_funcs/eager_nccl_broadcast.cpp
index bb6587f70fc..49bde68c7b8 100644
--- a/oneflow/core/autograd/gradient_funcs/eager_nccl_broadcast.cpp
+++ b/oneflow/core/autograd/gradient_funcs/eager_nccl_broadcast.cpp
@@ -52,12 +52,12 @@ Maybe<one::UserOpExpr> FindOrCreatEagerNcclReduceOpExpr(Symbol<ParallelDesc> par
 }
 
 }  // namespace
 
-struct EagerNcclBroadcastOpExprInterpState : public OpExprInterpState {
+struct EagerNcclBroadcastCaptureState : public AutoGradCaptureState {
   Symbol<ParallelDesc> parallel_desc;
   int64_t root;
 };
 
-class EagerNcclBroadcast : public OpExprGradFunction<EagerNcclBroadcastOpExprInterpState> {
+class EagerNcclBroadcast : public OpExprGradFunction<EagerNcclBroadcastCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -65,7 +65,7 @@ class EagerNcclBroadcast : public OpExprGradFunction<EagerNcclBroadcastOpExprIn
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(EagerNcclBroadcastOpExprInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(EagerNcclBroadcastCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs,
                       const OpExprInterpContext& interp_ctx) const override {
     ctx->root = JUST(interp_ctx.attrs.GetAttr<int64_t>("root"));
@@ -73,7 +73,7 @@ class EagerNcclBroadcast : public OpExprGradFunction<EagerNcclBroadcastOpExprIn
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const EagerNcclBroadcastOpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const EagerNcclBroadcastCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& grad_op = JUST(FindOrCreatEagerNcclReduceOpExpr(ctx->parallel_desc, ctx->root));
     in_grads->resize(1);
diff --git a/oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp b/oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp
index e155c6310b9..5d95292ed85 100644
--- a/oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp
+++ b/oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp
@@ -22,14 +22,14 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct ElementwiseXimumOpExprInterpState : public OpExprInterpState {
+struct ElementwiseXimumCaptureState : public AutoGradCaptureState {
   bool x_requires_grad;
   bool y_requires_grad;
 };
 
-class ElementwiseXimumOp : public OpExprGradFunction<ElementwiseXimumOpExprInterpState> {
+class ElementwiseXimumOp : public OpExprGradFunction<ElementwiseXimumCaptureState> {
  public:
-  Maybe<void> Capture(ElementwiseXimumOpExprInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(ElementwiseXimumCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     ctx->x_requires_grad = inputs.at(0)->requires_grad();
     ctx->y_requires_grad = inputs.at(1)->requires_grad();
@@ -38,7 +38,7 @@ class ElementwiseXimumOp : public OpExprGradFunction<ElementwiseXimumOpExprInte
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const ElementwiseXimumOpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ElementwiseXimumCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe<void>::Ok(); }
diff --git a/oneflow/core/autograd/gradient_funcs/expand.cpp b/oneflow/core/autograd/gradient_funcs/expand.cpp
index f735c1d0c20..b5e8fa7f25b 100644
--- a/oneflow/core/autograd/gradient_funcs/expand.cpp
+++ b/oneflow/core/autograd/gradient_funcs/expand.cpp
@@ -22,18 +22,18 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct ExpandInterpState : public OpExprInterpState {
+struct ExpandCaptureState : public AutoGradCaptureState {
   std::vector<int32_t> out_shape;
   std::vector<int32_t> stride;
   bool requires_grad;
 };
 
-class Expand : public OpExprGradFunction<ExpandInterpState> {
+class Expand : public OpExprGradFunction<ExpandCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(ExpandInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
-                      const AttrMap& attrs) const override;
-  Maybe<void> Apply(const ExpandInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Capture(ExpandCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -52,7 +52,7 @@ Maybe<void> Expand::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Expand::Capture(ExpandInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Expand::Capture(ExpandCaptureState* ctx, const TensorTuple& inputs,
                             const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -63,7 +63,7 @@ Maybe<void> Expand::Capture(ExpandInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Expand::Apply(const ExpandInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Expand::Apply(const ExpandCaptureState* ctx, const TensorTuple& out_grads,
                           TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/fake_quantization.cpp b/oneflow/core/autograd/gradient_funcs/fake_quantization.cpp
index 064b4875e23..006d9c7c24c 100644
--- a/oneflow/core/autograd/gradient_funcs/fake_quantization.cpp
+++ b/oneflow/core/autograd/gradient_funcs/fake_quantization.cpp
@@ -18,22 +18,22 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct FakeQuantizationInterpState : public OpExprInterpState {
+struct FakeQuantizationCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class FakeQuantization : public OpExprGradFunction<FakeQuantizationInterpState> {
+class FakeQuantization : public OpExprGradFunction<FakeQuantizationCaptureState> {
 public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(FakeQuantizationInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(FakeQuantizationCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 3);
     ctx->requires_grad = inputs.at(0)->requires_grad();
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const FakeQuantizationInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const FakeQuantizationCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(3);
diff --git a/oneflow/core/autograd/gradient_funcs/flatten.cpp b/oneflow/core/autograd/gradient_funcs/flatten.cpp
index e8effbccd17..4b0f46bb7bd 100644
--- a/oneflow/core/autograd/gradient_funcs/flatten.cpp
+++ b/oneflow/core/autograd/gradient_funcs/flatten.cpp
@@ -23,16 +23,16 @@ limitations under the License.
 namespace oneflow {
 namespace one {

-struct FlattenInterpState : public OpExprInterpState {
+struct FlattenCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };

-class Flatten : public OpExprGradFunction<FlattenInterpState> {
+class Flatten : public OpExprGradFunction<FlattenCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(FlattenInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(FlattenCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const FlattenInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;

  private:
@@ -47,7 +47,7 @@ Maybe<void> Flatten::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }

-Maybe<void> Flatten::Capture(FlattenInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Flatten::Capture(FlattenCaptureState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -55,7 +55,7 @@ Maybe<void> Flatten::Capture(FlattenInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }

-Maybe<void> Flatten::Apply(const FlattenInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Flatten::Apply(const FlattenCaptureState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/flip.cpp b/oneflow/core/autograd/gradient_funcs/flip.cpp
index 1022e4900a2..655821f4bd0 100644
--- a/oneflow/core/autograd/gradient_funcs/flip.cpp
+++ b/oneflow/core/autograd/gradient_funcs/flip.cpp
@@ -20,17 +20,17 @@ limitations under the License.
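// --------------------------------------------------------------------------
// Flatten above shows the short-circuit idiom used by most single-input ops
// in this diff (Expand, Flip, Squeeze, ...): Capture records whether a
// gradient is wanted and returns early, and Apply does the same before doing
// any work. Fragment sketch of those two hooks, with a hypothetical
// FooCaptureState:
Maybe<void> Capture(FooCaptureState* ctx, const TensorTuple& inputs,
                    const TensorTuple& outputs, const AttrMap& attrs) const override {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }  // nothing to record
  ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> Apply(const FooCaptureState* ctx, const TensorTuple& out_grads,
                  TensorTuple* in_grads) const override {
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }  // skip backward work
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);
  in_grads->resize(1);
  return Maybe<void>::Ok();
}
// --------------------------------------------------------------------------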
namespace oneflow { namespace one { -struct FlipInterpState : public OpExprInterpState { +struct FlipCaptureState : public AutoGradCaptureState { bool requires_grad; std::vector dims; }; -class Flip : public OpExprGradFunction { +class Flip : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(FlipInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + Maybe Capture(FlipCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const FlipInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -44,7 +44,7 @@ Maybe Flip::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe Flip::Capture(FlipInterpState* ctx, const TensorTuple& inputs, +Maybe Flip::Capture(FlipCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } @@ -53,7 +53,7 @@ Maybe Flip::Capture(FlipInterpState* ctx, const TensorTuple& inputs, return Maybe::Ok(); } -Maybe Flip::Apply(const FlipInterpState* ctx, const TensorTuple& out_grads, +Maybe Flip::Apply(const FlipCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/gather.cpp b/oneflow/core/autograd/gradient_funcs/gather.cpp index 0f83aceb71f..042f193f3c0 100644 --- a/oneflow/core/autograd/gradient_funcs/gather.cpp +++ b/oneflow/core/autograd/gradient_funcs/gather.cpp @@ -23,17 +23,17 @@ limitations under the License. namespace oneflow { namespace one { -struct GatherInterpState : public OpExprInterpState { +struct GatherCaptureState : public AutoGradCaptureState { int64_t axis; bool requires_grad; }; -class Gather : public OpExprGradFunction { +class Gather : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(GatherInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrMap& attrs) const override; - Maybe Apply(const GatherInterpState* ctx, const TensorTuple& out_grads, + Maybe Capture(GatherCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -47,7 +47,7 @@ Maybe Gather::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe Gather::Capture(GatherInterpState* ctx, const TensorTuple& inputs, +Maybe Gather::Capture(GatherCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } @@ -60,7 +60,7 @@ Maybe Gather::Capture(GatherInterpState* ctx, const TensorTuple& inputs, return Maybe::Ok(); } -Maybe Gather::Apply(const GatherInterpState* ctx, const TensorTuple& out_grads, +Maybe Gather::Apply(const GatherCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); diff --git a/oneflow/core/autograd/gradient_funcs/gather_nd.cpp b/oneflow/core/autograd/gradient_funcs/gather_nd.cpp index 3ba5b07a21e..84764cb953c 100644 --- 
a/oneflow/core/autograd/gradient_funcs/gather_nd.cpp +++ b/oneflow/core/autograd/gradient_funcs/gather_nd.cpp @@ -19,15 +19,15 @@ limitations under the License. namespace oneflow { namespace one { -struct GatherNdInterpState : public OpExprInterpState { +struct GatherNdCaptureState : public AutoGradCaptureState { bool requires_grad; }; -class GatherNd : public OpExprGradFunction { +class GatherNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } - Maybe Capture(GatherNdInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(GatherNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -39,7 +39,7 @@ class GatherNd : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const GatherNdInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const GatherNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(2); diff --git a/oneflow/core/autograd/gradient_funcs/hierarchical_parallel_cast.cpp b/oneflow/core/autograd/gradient_funcs/hierarchical_parallel_cast.cpp index 8dc2bce6207..c591adb49f6 100644 --- a/oneflow/core/autograd/gradient_funcs/hierarchical_parallel_cast.cpp +++ b/oneflow/core/autograd/gradient_funcs/hierarchical_parallel_cast.cpp @@ -49,12 +49,11 @@ Maybe FindOrCreatHierarchicalParallelCastOpExpr(Symbol nd_sbp; }; -class HerarchicalParallelCast - : public OpExprGradFunction { +class HerarchicalParallelCast : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -62,14 +61,14 @@ class HerarchicalParallelCast return Maybe::Ok(); } - Maybe Capture(HerarchicalParallelCastOpExprInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(HerarchicalParallelCastCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->nd_sbp = JUST(inputs.at(0)->nd_sbp()); return Maybe::Ok(); } - Maybe Apply(const HerarchicalParallelCastOpExprInterpState* ctx, - const TensorTuple& out_grads, TensorTuple* in_grads) const override { + Maybe Apply(const HerarchicalParallelCastCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { const auto& grad_op = JUST(FindOrCreatHierarchicalParallelCastOpExpr(ctx->nd_sbp)); CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/identity.cpp b/oneflow/core/autograd/gradient_funcs/identity.cpp index c47a14c5f51..0c929f0284b 100644 --- a/oneflow/core/autograd/gradient_funcs/identity.cpp +++ b/oneflow/core/autograd/gradient_funcs/identity.cpp @@ -18,22 +18,22 @@ limitations under the License. 
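// --------------------------------------------------------------------------
// Attribute-carrying ops (Gather above; LeakyRelu, ReduceSum, Slice elsewhere
// in this diff) snapshot the op's default attributes once in Init, then
// overlay the per-call attributes in Capture. Sketch of the idiom; the "axis"
// attribute and BarCaptureState are illustrative:
Maybe<void> Init(const OpExpr& op) override {
  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());  // defaults
  return Maybe<void>::Ok();
}

Maybe<void> Capture(BarCaptureState* ctx, const TensorTuple& inputs,
                    const TensorTuple& outputs, const AttrMap& attrs) const override {
  ComposedAttrMap composed_attrs(attrs, base_attrs_);  // call-site attrs win
  ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
  return Maybe<void>::Ok();
}
// --------------------------------------------------------------------------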
namespace oneflow { namespace one { -struct IdentityInterpState : public OpExprInterpState { +struct IdentityCaptureState : public AutoGradCaptureState { bool requires_grad; }; -class Identity : public OpExprGradFunction { +class Identity : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } - Maybe Capture(IdentityInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(IdentityCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } - Maybe Apply(const IdentityInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const IdentityCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/l2_normalize.cpp b/oneflow/core/autograd/gradient_funcs/l2_normalize.cpp index 5dc22de5b8f..bfa7d5687e6 100644 --- a/oneflow/core/autograd/gradient_funcs/l2_normalize.cpp +++ b/oneflow/core/autograd/gradient_funcs/l2_normalize.cpp @@ -22,18 +22,18 @@ limitations under the License. namespace oneflow { namespace one { -struct L2NormalizeInterpState : public OpExprInterpState { +struct L2NormalizeCaptureState : public AutoGradCaptureState { int64_t axis; float epsilon; bool requires_grad; }; -class L2Normalize : public OpExprGradFunction { +class L2Normalize : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(L2NormalizeInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(L2NormalizeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const L2NormalizeInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const L2NormalizeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -47,7 +47,7 @@ Maybe L2Normalize::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe L2Normalize::Capture(L2NormalizeInterpState* ctx, const TensorTuple& inputs, +Maybe L2Normalize::Capture(L2NormalizeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } @@ -61,7 +61,7 @@ Maybe L2Normalize::Capture(L2NormalizeInterpState* ctx, const TensorTuple& return Maybe::Ok(); } -Maybe L2Normalize::Apply(const L2NormalizeInterpState* ctx, const TensorTuple& out_grads, +Maybe L2Normalize::Apply(const L2NormalizeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp index fcaa5402bb3..fc950643853 100644 --- a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp +++ b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp @@ -23,7 +23,7 @@ limitations under the License. 
namespace oneflow { namespace one { -struct LayerNormInterpState : public OpExprInterpState { +struct LayerNormCaptureState : public AutoGradCaptureState { bool center; bool scale; @@ -47,14 +47,14 @@ struct LayerNormInterpState : public OpExprInterpState { // y, mean, inv_variance, [normalized] = // layer_norm(x, [beta], [gamma], center=False, scale=False, begin_norm_axis=1, // begin_params_axis=-1, epsilon=1e-5) -class LayerNorm : public OpExprGradFunction { +class LayerNorm : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(LayerNormInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(LayerNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const LayerNormInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const LayerNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -73,7 +73,7 @@ Maybe LayerNorm::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe LayerNorm::Capture(LayerNormInterpState* ctx, const TensorTuple& inputs, +Maybe LayerNorm::Capture(LayerNormCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->center = JUST(composed_attrs.GetAttr("center")); @@ -101,7 +101,7 @@ Maybe LayerNorm::Capture(LayerNormInterpState* ctx, const TensorTuple& inp return Maybe::Ok(); } -Maybe LayerNorm::Apply(const LayerNormInterpState* ctx, const TensorTuple& out_grads, +Maybe LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { const auto& saved_tensors = ctx->SavedTensors(); in_grads->resize(ctx->center + ctx->scale + 1); diff --git a/oneflow/core/autograd/gradient_funcs/logsoftmax.cpp b/oneflow/core/autograd/gradient_funcs/logsoftmax.cpp index fc247d18883..ef50dbff2d1 100644 --- a/oneflow/core/autograd/gradient_funcs/logsoftmax.cpp +++ b/oneflow/core/autograd/gradient_funcs/logsoftmax.cpp @@ -23,16 +23,16 @@ limitations under the License. 
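// --------------------------------------------------------------------------
// LayerNorm above saves several tensors; Apply recovers them purely by
// position, so the SavedTensors() indices must mirror the order of the
// SaveTensorForBackward calls in Capture. Two-tensor fragment sketch
// (hypothetical names):
Maybe<void> Capture(BazCaptureState* ctx, const TensorTuple& inputs,
                    const TensorTuple& outputs, const AttrMap& attrs) const override {
  ctx->SaveTensorForBackward(inputs.at(0));   // slot 0: x
  ctx->SaveTensorForBackward(outputs.at(0));  // slot 1: y
  return Maybe<void>::Ok();
}

Maybe<void> Apply(const BazCaptureState* ctx, const TensorTuple& out_grads,
                  TensorTuple* in_grads) const override {
  const auto& saved_tensors = ctx->SavedTensors();
  const auto& x = saved_tensors.at(0);  // same order as saved above
  const auto& y = saved_tensors.at(1);
  in_grads->resize(1);
  // ... compute in_grads->at(0) from x, y, and out_grads.at(0) ...
  return Maybe<void>::Ok();
}
// --------------------------------------------------------------------------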
namespace oneflow { namespace one { -struct LogSoftmaxInterpState : public OpExprInterpState { +struct LogSoftmaxCaptureState : public AutoGradCaptureState { bool requires_grad; }; -class LogSoftmax : public OpExprGradFunction { +class LogSoftmax : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(LogSoftmaxInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(LogSoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const LogSoftmaxInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const LogSoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -53,7 +53,7 @@ Maybe LogSoftmax::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe LogSoftmax::Capture(LogSoftmaxInterpState* ctx, const TensorTuple& inputs, +Maybe LogSoftmax::Capture(LogSoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); CHECK_EQ_OR_RETURN(inputs.size(), 1); @@ -65,7 +65,7 @@ Maybe LogSoftmax::Capture(LogSoftmaxInterpState* ctx, const TensorTuple& i return Maybe::Ok(); } -Maybe LogSoftmax::Apply(const LogSoftmaxInterpState* ctx, const TensorTuple& out_grads, +Maybe LogSoftmax::Apply(const LogSoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) return Maybe::Ok(); CHECK_EQ_OR_RETURN(out_grads.size(), 2); diff --git a/oneflow/core/autograd/gradient_funcs/math_binary_op.cpp b/oneflow/core/autograd/gradient_funcs/math_binary_op.cpp index df653f3de16..bdb19599c66 100644 --- a/oneflow/core/autograd/gradient_funcs/math_binary_op.cpp +++ b/oneflow/core/autograd/gradient_funcs/math_binary_op.cpp @@ -23,13 +23,13 @@ limitations under the License. namespace oneflow { namespace one { -struct BinaryMathOpExprInterpState : public OpExprInterpState { +struct BinaryMathCaptureState : public AutoGradCaptureState { bool x_requires_grad; bool y_requires_grad; }; -class BinaryMathOp : public OpExprGradFunction { - Maybe Capture(BinaryMathOpExprInterpState* ctx, const TensorTuple& inputs, +class BinaryMathOp : public OpExprGradFunction { + Maybe Capture(BinaryMathCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->y_requires_grad = inputs.at(1)->requires_grad(); @@ -38,7 +38,7 @@ class BinaryMathOp : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const BinaryMathOpExprInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const BinaryMathCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/math_unary_op.cpp b/oneflow/core/autograd/gradient_funcs/math_unary_op.cpp index 797f9112b9d..5a383cbecf0 100644 --- a/oneflow/core/autograd/gradient_funcs/math_unary_op.cpp +++ b/oneflow/core/autograd/gradient_funcs/math_unary_op.cpp @@ -23,19 +23,19 @@ limitations under the License. 
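// --------------------------------------------------------------------------
// UnaryMathOp just below (and MatmulBase after it) pairs a backward OpExpr,
// built once, with the tensors saved at capture time. A sketch assuming the
// one::OpBuilder and OpInterpUtil::Dispatch helpers used by these files;
// "my_unary_grad" and all type names here are hypothetical:
struct MyUnaryCaptureState : public AutoGradCaptureState {
  bool x_requires_grad;
};

class MyUnaryGrad : public OpExprGradFunction<MyUnaryCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    grad_op_ = JUST(one::OpBuilder("my_unary_grad").Input("x").Input("dy").Output("dx").Build());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(MyUnaryCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const MyUnaryCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->x_requires_grad) { return Maybe<void>::Ok(); }
    const auto& x = ctx->SavedTensors().at(0);
    in_grads->resize(1);
    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {x, out_grads.at(0)}));
    return Maybe<void>::Ok();
  }

 private:
  std::shared_ptr<OpExpr> grad_op_;
};
// --------------------------------------------------------------------------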
namespace oneflow { namespace one { -struct UnaryMathOpExprInterpState : public OpExprInterpState { +struct UnaryMathCaptureState : public AutoGradCaptureState { bool x_requires_grad; }; -class UnaryMathOp : public OpExprGradFunction { - Maybe Capture(UnaryMathOpExprInterpState* ctx, const TensorTuple& inputs, +class UnaryMathOp : public OpExprGradFunction { + Maybe Capture(UnaryMathCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } - Maybe Apply(const UnaryMathOpExprInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const UnaryMathCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (!ctx->x_requires_grad) { return Maybe::Ok(); } const auto& x = ctx->SavedTensors().at(0); diff --git a/oneflow/core/autograd/gradient_funcs/matmul.cpp b/oneflow/core/autograd/gradient_funcs/matmul.cpp index b684c2739c7..283768fb666 100644 --- a/oneflow/core/autograd/gradient_funcs/matmul.cpp +++ b/oneflow/core/autograd/gradient_funcs/matmul.cpp @@ -22,7 +22,7 @@ limitations under the License. namespace oneflow { namespace one { -struct MatmulInterpState : public OpExprInterpState { +struct MatmulCaptureState : public AutoGradCaptureState { bool transpose_a; bool transpose_b; double alpha; @@ -32,11 +32,11 @@ struct MatmulInterpState : public OpExprInterpState { size_t b_index; }; -class MatmulBase : public OpExprGradFunction { +class MatmulBase : public OpExprGradFunction { public: - Maybe Capture(MatmulInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrMap& attrs) const override; - Maybe Apply(const MatmulInterpState* ctx, const TensorTuple& out_grads, + Maybe Capture(MatmulCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; protected: @@ -45,7 +45,7 @@ class MatmulBase : public OpExprGradFunction { std::shared_ptr grad_b_op_; }; -Maybe MatmulBase::Capture(MatmulInterpState* ctx, const TensorTuple& inputs, +Maybe MatmulBase::Capture(MatmulCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad_a = inputs.at(0)->requires_grad(); ctx->requires_grad_b = inputs.at(1)->requires_grad(); @@ -64,7 +64,7 @@ Maybe MatmulBase::Capture(MatmulInterpState* ctx, const TensorTuple& input return Maybe::Ok(); } -Maybe MatmulBase::Apply(const MatmulInterpState* ctx, const TensorTuple& out_grads, +Maybe MatmulBase::Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); @@ -150,7 +150,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("batch_matmul", BatchMatmul); class BroadcastMatmul : public MatmulBase { public: Maybe Init(const OpExpr& op) override; - Maybe Apply(const MatmulInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const MatmulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; }; @@ -168,7 +168,7 @@ Maybe BroadcastMatmul::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe BroadcastMatmul::Apply(const MatmulInterpState* ctx, const TensorTuple& out_grads, +Maybe BroadcastMatmul::Apply(const MatmulCaptureState* ctx, const 
TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad_a && !ctx->requires_grad_b) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); diff --git a/oneflow/core/autograd/gradient_funcs/multiply.cpp b/oneflow/core/autograd/gradient_funcs/multiply.cpp index ff37dd5a40e..6ac80541b5e 100644 --- a/oneflow/core/autograd/gradient_funcs/multiply.cpp +++ b/oneflow/core/autograd/gradient_funcs/multiply.cpp @@ -19,19 +19,19 @@ limitations under the License. namespace oneflow { namespace one { -struct MultiplyInterpState : public OpExprInterpState { +struct MultiplyCaptureState : public AutoGradCaptureState { bool requires_grad_x; bool requires_grad_y; int32_t index_x; int32_t index_y; }; -class Multiply : public OpExprGradFunction { +class Multiply : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(MultiplyInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(MultiplyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const MultiplyInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const MultiplyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -45,7 +45,7 @@ Maybe Multiply::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe Multiply::Capture(MultiplyInterpState* ctx, const TensorTuple& inputs, +Maybe Multiply::Capture(MultiplyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->requires_grad_x = inputs.at(0)->requires_grad(); @@ -55,7 +55,7 @@ Maybe Multiply::Capture(MultiplyInterpState* ctx, const TensorTuple& input return Maybe::Ok(); } -Maybe Multiply::Apply(const MultiplyInterpState* ctx, const TensorTuple& out_grads, +Maybe Multiply::Apply(const MultiplyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(2); diff --git a/oneflow/core/autograd/gradient_funcs/narrow.cpp b/oneflow/core/autograd/gradient_funcs/narrow.cpp index ddaccd47ab2..c5e2a427856 100644 --- a/oneflow/core/autograd/gradient_funcs/narrow.cpp +++ b/oneflow/core/autograd/gradient_funcs/narrow.cpp @@ -21,14 +21,14 @@ limitations under the License. 
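// --------------------------------------------------------------------------
// Two-input ops such as Multiply above keep one flag per input and fill only
// the gradient slots that are actually needed; in_grads is still resized to
// the full forward-input count. Fragment sketch (hypothetical state name):
Maybe<void> Apply(const MulLikeCaptureState* ctx, const TensorTuple& out_grads,
                  TensorTuple* in_grads) const override {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);
  in_grads->resize(2);  // one slot per forward input, filled selectively
  if (ctx->requires_grad_x) {
    // ... fill in_grads->at(0) from the saved y and out_grads.at(0) ...
  }
  if (ctx->requires_grad_y) {
    // ... fill in_grads->at(1) from the saved x and out_grads.at(0) ...
  }
  return Maybe<void>::Ok();
}
// --------------------------------------------------------------------------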
namespace oneflow { namespace one { -struct NarrowOpInterpState : public OpExprInterpState { +struct NarrowCaptureState : public AutoGradCaptureState { bool requires_grad; int64_t dim; int64_t start; int64_t length; }; -class NarrowOp : public OpExprGradFunction { +class Narrow : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -37,7 +37,7 @@ class NarrowOp : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(NarrowOpInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(NarrowCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -52,7 +52,7 @@ class NarrowOp : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const NarrowOpInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const NarrowCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { if (ctx->requires_grad) { const auto& like = ctx->SavedTensors().at(0); @@ -67,7 +67,7 @@ class NarrowOp : public OpExprGradFunction { AttrMap base_attrs_; }; -REGISTER_OP_EXPR_GRAD_FUNCTION("narrow", NarrowOp); +REGISTER_OP_EXPR_GRAD_FUNCTION("narrow", Narrow); } // namespace one } // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/normalization.cpp b/oneflow/core/autograd/gradient_funcs/normalization.cpp index 53b6973841a..b3c71194f02 100644 --- a/oneflow/core/autograd/gradient_funcs/normalization.cpp +++ b/oneflow/core/autograd/gradient_funcs/normalization.cpp @@ -21,7 +21,7 @@ limitations under the License. namespace oneflow { namespace one { -struct NormalizationGradInterpState : public OpExprInterpState { +struct NormalizationGradCaptureState : public AutoGradCaptureState { int32_t axis; float epsilon; bool track_running_stats; @@ -39,7 +39,7 @@ struct NormalizationGradInterpState : public OpExprInterpState { // inference: // y = normalization(x, moving_mean, moving_variance, gamma, beta, axis=1, epsilon=0.01, // momentum=0.9) -class NormalizationGrad : public OpExprGradFunction { +class NormalizationGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -48,7 +48,7 @@ class NormalizationGrad : public OpExprGradFunction::Ok(); } - Maybe Capture(NormalizationGradInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(NormalizationGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->x_requires_grad = inputs.at(0)->requires_grad(); std::shared_ptr gamma, beta; @@ -81,7 +81,7 @@ class NormalizationGrad : public OpExprGradFunction::Ok(); } - Maybe Apply(const NormalizationGradInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const NormalizationGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors().at(0); // x const auto& gamma = ctx->SavedTensors().at(1); // gamma diff --git a/oneflow/core/autograd/gradient_funcs/padding.cpp b/oneflow/core/autograd/gradient_funcs/padding.cpp index 3d54edadeec..040549730e5 100644 --- a/oneflow/core/autograd/gradient_funcs/padding.cpp +++ b/oneflow/core/autograd/gradient_funcs/padding.cpp @@ -19,12 +19,12 @@ limitations under the License. 
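// --------------------------------------------------------------------------
// Note on renames like NarrowOp -> Narrow above: only the C++ identifier
// changes. The quoted op type name in the registration macro is the lookup
// key the autograd engine uses, so it must stay stable across the rename:
REGISTER_OP_EXPR_GRAD_FUNCTION("narrow", Narrow);  // as in this diff
// --------------------------------------------------------------------------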
namespace oneflow { namespace one { -struct Pad2dInterpState : public OpExprInterpState { +struct Pad2dCaptureState : public AutoGradCaptureState { bool requires_grad; std::vector paddings; }; -class Pad2d : public OpExprGradFunction { +class Pad2d : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); @@ -33,7 +33,7 @@ class Pad2d : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(Pad2dInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + Maybe Capture(Pad2dCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -51,7 +51,7 @@ class Pad2d : public OpExprGradFunction { class ReflectionPad2d : public Pad2d { public: - Maybe Apply(const Pad2dInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const Pad2dCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); @@ -64,7 +64,7 @@ class ReflectionPad2d : public Pad2d { class ReplicationPad2d : public Pad2d { public: - Maybe Apply(const Pad2dInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const Pad2dCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); @@ -78,13 +78,13 @@ class ReplicationPad2d : public Pad2d { REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad2d", ReflectionPad2d); REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad2d", ReplicationPad2d); -struct ConstantPadNdInterpState : public OpExprInterpState { +struct ConstantPadNdCaptureState : public AutoGradCaptureState { bool requires_grad; std::vector paddings; functional::Scalar padding_value; }; -class ConstantPadNd : public OpExprGradFunction { +class ConstantPadNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const UserOpExpr* fw_op_expr = dynamic_cast(&op); @@ -93,7 +93,7 @@ class ConstantPadNd : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(ConstantPadNdInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ConstantPadNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -112,7 +112,7 @@ class ConstantPadNd : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const ConstantPadNdInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ConstantPadNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/pool.cpp b/oneflow/core/autograd/gradient_funcs/pool.cpp index 49d9880a79f..e1ae9dc278d 100644 --- a/oneflow/core/autograd/gradient_funcs/pool.cpp +++ b/oneflow/core/autograd/gradient_funcs/pool.cpp @@ -26,7 +26,7 @@ namespace one { namespace { -struct PoolInterpState : public OpExprInterpState { +struct PoolCaptureState : public AutoGradCaptureState { bool requires_grad; size_t input_index; size_t output_index; @@ -40,16 +40,16 @@ struct PoolInterpState : public OpExprInterpState { bool ceil_mode; }; -class PoolNdGrad : public OpExprGradFunction { +class PoolNdGrad : public OpExprGradFunction { public: virtual ~PoolNdGrad() = default; - using OpExprGradFunction::Init; + 
using OpExprGradFunction::Init; Maybe Init(const OpExpr& op, const std::string& mode); - Maybe Capture(PoolInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + Maybe Capture(PoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const PoolInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const PoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -65,7 +65,7 @@ Maybe PoolNdGrad::Init(const OpExpr& op, const std::string& mode) { return Maybe::Ok(); } -Maybe PoolNdGrad::Capture(PoolInterpState* ctx, const TensorTuple& inputs, +Maybe PoolNdGrad::Capture(PoolCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } @@ -84,7 +84,7 @@ Maybe PoolNdGrad::Capture(PoolInterpState* ctx, const TensorTuple& inputs, return Maybe::Ok(); } -Maybe PoolNdGrad::Apply(const PoolInterpState* ctx, const TensorTuple& out_grads, +Maybe PoolNdGrad::Apply(const PoolCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_EQ_OR_RETURN(out_grads.size(), 1); diff --git a/oneflow/core/autograd/gradient_funcs/pooling.cpp b/oneflow/core/autograd/gradient_funcs/pooling.cpp index 7111d91195e..e0c811870ca 100644 --- a/oneflow/core/autograd/gradient_funcs/pooling.cpp +++ b/oneflow/core/autograd/gradient_funcs/pooling.cpp @@ -27,7 +27,7 @@ namespace one { namespace { -struct PoolingInterpState : public OpExprInterpState { +struct PoolingCaptureState : public AutoGradCaptureState { bool requires_grad; size_t input_index; size_t output_index; @@ -42,16 +42,16 @@ struct PoolingInterpState : public OpExprInterpState { bool ceil_mode; }; -class PoolingNdGrad : public OpExprGradFunction { +class PoolingNdGrad : public OpExprGradFunction { public: virtual ~PoolingNdGrad() = default; - using OpExprGradFunction::Init; + using OpExprGradFunction::Init; Maybe Init(const OpExpr& op, const std::string& mode); - Maybe Capture(PoolingInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(PoolingCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const PoolingInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const PoolingCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -67,7 +67,7 @@ Maybe PoolingNdGrad::Init(const OpExpr& op, const std::string& mode) { return Maybe::Ok(); } -Maybe PoolingNdGrad::Capture(PoolingInterpState* ctx, const TensorTuple& inputs, +Maybe PoolingNdGrad::Capture(PoolingCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } @@ -87,7 +87,7 @@ Maybe PoolingNdGrad::Capture(PoolingInterpState* ctx, const TensorTuple& i return Maybe::Ok(); } -Maybe PoolingNdGrad::Apply(const PoolingInterpState* ctx, const TensorTuple& out_grads, +Maybe PoolingNdGrad::Apply(const PoolingCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } CHECK_LE_OR_RETURN(out_grads.size(), 2); diff --git a/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp b/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp index 
987d1f8f710..0a28fb336e3 100644 --- a/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp +++ b/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp @@ -22,39 +22,39 @@ limitations under the License. namespace oneflow { namespace one { -struct ReduceSumOpInterpState : public OpExprInterpState { +struct ReduceSumCaptureState : public AutoGradCaptureState { std::vector axis; }; -class ReduceSumOp : public OpExprGradFunction { +class ReduceSum : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ReduceSumCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ReduceSumCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; -Maybe ReduceSumOp::Init(const OpExpr& op) { +Maybe ReduceSum::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } -Maybe ReduceSumOp::Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrMap& attrs) const { +Maybe ReduceSum::Capture(ReduceSumCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr>("axis")); ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } -Maybe ReduceSumOp::Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads, - TensorTuple* in_grads) const { +Maybe ReduceSum::Apply(const ReduceSumCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { const auto& input = ctx->SavedTensors().at(0); const auto& dy = out_grads.at(0); in_grads->resize(1); @@ -62,34 +62,34 @@ Maybe ReduceSumOp::Apply(const ReduceSumOpInterpState* ctx, const TensorTu return Maybe::Ok(); } -REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum", ReduceSumOp); +REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum", ReduceSum); -struct ReduceMaxOrMinOpInterpState : public OpExprInterpState { +struct ReduceMaxOrMinCaptureState : public AutoGradCaptureState { std::vector axis; bool keepdims; }; -class ReduceMaxOrMinOp : public OpExprGradFunction { +class ReduceMaxOrMin : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ReduceMaxOrMinCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ReduceMaxOrMinCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; -Maybe ReduceMaxOrMinOp::Init(const OpExpr& op) { +Maybe ReduceMaxOrMin::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } -Maybe ReduceMaxOrMinOp::Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrMap& attrs) const { +Maybe ReduceMaxOrMin::Capture(ReduceMaxOrMinCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& 
outputs, const AttrMap& attrs) const { ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->axis = JUST(composed_attrs.GetAttr>("axis")); ctx->keepdims = JUST(composed_attrs.GetAttr("keepdims")); @@ -98,8 +98,8 @@ Maybe ReduceMaxOrMinOp::Capture(ReduceMaxOrMinOpInterpState* ctx, const Te return Maybe::Ok(); } -Maybe ReduceMaxOrMinOp::Apply(const ReduceMaxOrMinOpInterpState* ctx, - const TensorTuple& out_grads, TensorTuple* in_grads) const { +Maybe ReduceMaxOrMin::Apply(const ReduceMaxOrMinCaptureState* ctx, + const TensorTuple& out_grads, TensorTuple* in_grads) const { const auto& input = ctx->SavedTensors().at(0); const auto& output = ctx->SavedTensors().at(1); const auto& dy = out_grads.at(0); @@ -116,8 +116,8 @@ Maybe ReduceMaxOrMinOp::Apply(const ReduceMaxOrMinOpInterpState* ctx, return Maybe::Ok(); } -REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min", ReduceMaxOrMinOp); -REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max", ReduceMaxOrMinOp); +REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min", ReduceMaxOrMin); +REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max", ReduceMaxOrMin); } // namespace one } // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/reshape.cpp b/oneflow/core/autograd/gradient_funcs/reshape.cpp index dcc99995288..c743e4cb6e9 100644 --- a/oneflow/core/autograd/gradient_funcs/reshape.cpp +++ b/oneflow/core/autograd/gradient_funcs/reshape.cpp @@ -24,7 +24,7 @@ limitations under the License. namespace oneflow { namespace one { -class ReshapeOpExprGrad : public OpExprGradFunction { +class ReshapeOpExprGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -33,13 +33,13 @@ class ReshapeOpExprGrad : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrMap& attrs) const override { + Maybe Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } - Maybe Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& saved_tensors = ctx->SavedTensors(); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/scalar_add.cpp b/oneflow/core/autograd/gradient_funcs/scalar_add.cpp index b8db59221a9..3cfaaa7c0ce 100644 --- a/oneflow/core/autograd/gradient_funcs/scalar_add.cpp +++ b/oneflow/core/autograd/gradient_funcs/scalar_add.cpp @@ -19,22 +19,22 @@ limitations under the License. 
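// --------------------------------------------------------------------------
// ReshapeOpExprGrad above needs no per-op fields, so it parameterizes
// OpExprGradFunction with the plain AutoGradCaptureState base, which already
// provides SaveTensorForBackward()/SavedTensors(). Sketch ("Passthrough" is a
// hypothetical name):
class Passthrough : public OpExprGradFunction<AutoGradCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->SaveTensorForBackward(inputs.at(0));  // keep x, e.g. to recover its shape
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    const auto& saved_tensors = ctx->SavedTensors();
    in_grads->resize(1);
    // ... reshape out_grads.at(0) like saved_tensors.at(0) ...
    return Maybe<void>::Ok();
  }
};
// --------------------------------------------------------------------------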
namespace oneflow { namespace one { -struct ScalarAddInterpState : public OpExprInterpState { +struct ScalarAddCaptureState : public AutoGradCaptureState { bool requires_grad; }; -class ScalarAdd : public OpExprGradFunction { +class ScalarAdd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } - Maybe Capture(ScalarAddInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ScalarAddCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } - Maybe Apply(const ScalarAddInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ScalarAddCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/scalar_fmod.cpp b/oneflow/core/autograd/gradient_funcs/scalar_fmod.cpp index d5d1b45ece9..0922f391054 100644 --- a/oneflow/core/autograd/gradient_funcs/scalar_fmod.cpp +++ b/oneflow/core/autograd/gradient_funcs/scalar_fmod.cpp @@ -20,22 +20,22 @@ limitations under the License. namespace oneflow { namespace one { -struct ScalarFModGradInterpState : public OpExprInterpState { +struct ScalarFModGradCaptureState : public AutoGradCaptureState { bool requires_grad; }; -class ScalarFModGrad : public OpExprGradFunction { +class ScalarFModGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } - Maybe Capture(ScalarFModGradInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ScalarFModGradCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); return Maybe::Ok(); } - Maybe Apply(const ScalarFModGradInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ScalarFModGradCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/scalar_mul.cpp b/oneflow/core/autograd/gradient_funcs/scalar_mul.cpp index bc64ff10126..6f9942bb56e 100644 --- a/oneflow/core/autograd/gradient_funcs/scalar_mul.cpp +++ b/oneflow/core/autograd/gradient_funcs/scalar_mul.cpp @@ -20,12 +20,12 @@ limitations under the License. 
namespace oneflow { namespace one { -struct ScalarMulInterpState : public OpExprInterpState { +struct ScalarMulCaptureState : public AutoGradCaptureState { bool requires_grad; functional::Scalar operand; }; -class ScalarMul : public OpExprGradFunction { +class ScalarMul : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -34,7 +34,7 @@ class ScalarMul : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(ScalarMulInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ScalarMulCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); @@ -49,7 +49,7 @@ class ScalarMul : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const ScalarMulInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ScalarMulCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(1); diff --git a/oneflow/core/autograd/gradient_funcs/scalar_pow.cpp b/oneflow/core/autograd/gradient_funcs/scalar_pow.cpp index 03f021aaad5..19946f3d228 100644 --- a/oneflow/core/autograd/gradient_funcs/scalar_pow.cpp +++ b/oneflow/core/autograd/gradient_funcs/scalar_pow.cpp @@ -22,12 +22,12 @@ limitations under the License. namespace oneflow { namespace one { -struct ScalarPowInterpState : public OpExprInterpState { +struct ScalarPowCaptureState : public AutoGradCaptureState { bool requires_grad; double exponent; }; -class ScalarPow : public OpExprGradFunction { +class ScalarPow : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -38,7 +38,7 @@ class ScalarPow : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(ScalarPowInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ScalarPowCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -51,7 +51,7 @@ class ScalarPow : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const ScalarPowInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors().at(0); MutableAttrMap attrs; diff --git a/oneflow/core/autograd/gradient_funcs/scatter_nd.cpp b/oneflow/core/autograd/gradient_funcs/scatter_nd.cpp index 6f49d9df750..6f8119b236b 100644 --- a/oneflow/core/autograd/gradient_funcs/scatter_nd.cpp +++ b/oneflow/core/autograd/gradient_funcs/scatter_nd.cpp @@ -19,15 +19,15 @@ limitations under the License. 
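// --------------------------------------------------------------------------
// ScalarPow above re-materializes attributes at backward time: the exponent
// captured during the forward pass parameterizes the backward op through a
// MutableAttrMap. Fragment sketch of such an Apply, assuming the SetAttr
// accessor used by this file:
Maybe<void> Apply(const ScalarPowCaptureState* ctx, const TensorTuple& out_grads,
                  TensorTuple* in_grads) const override {
  const auto& x = ctx->SavedTensors().at(0);
  MutableAttrMap attrs;
  JUST(attrs.SetAttr<double>("exponent", ctx->exponent));
  in_grads->resize(1);
  // ... dispatch the backward op with x, out_grads.at(0), and attrs ...
  return Maybe<void>::Ok();
}
// --------------------------------------------------------------------------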
namespace oneflow { namespace one { -struct ScatterNdInterpState : public OpExprInterpState { +struct ScatterNdCaptureState : public AutoGradCaptureState { bool requires_grad; }; -class ScatterNd : public OpExprGradFunction { +class ScatterNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } - Maybe Capture(ScatterNdInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(ScatterNdCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -38,7 +38,7 @@ class ScatterNd : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const ScatterNdInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ScatterNdCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(2); diff --git a/oneflow/core/autograd/gradient_funcs/select_first.cpp b/oneflow/core/autograd/gradient_funcs/select_first.cpp index f14741fccb9..95ca051a336 100644 --- a/oneflow/core/autograd/gradient_funcs/select_first.cpp +++ b/oneflow/core/autograd/gradient_funcs/select_first.cpp @@ -23,22 +23,22 @@ limitations under the License. namespace oneflow { namespace one { -struct SelectFirstExprInterpState : public OpExprInterpState { +struct SelectFirstCaptureState : public AutoGradCaptureState { TensorTuple inputs; bool requires_grad; }; -class SelectFirst : public OpExprGradFunction { +class SelectFirst : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { return Maybe::Ok(); } - Maybe Capture(SelectFirstExprInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(SelectFirstCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->inputs = inputs; return Maybe::Ok(); } - Maybe Apply(const SelectFirstExprInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const SelectFirstCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->at(0) = out_grads.at(0); for (int i = 1; i < ctx->inputs.size(); i++) { diff --git a/oneflow/core/autograd/gradient_funcs/slice.cpp b/oneflow/core/autograd/gradient_funcs/slice.cpp index 85b84562479..e5c9e226d79 100644 --- a/oneflow/core/autograd/gradient_funcs/slice.cpp +++ b/oneflow/core/autograd/gradient_funcs/slice.cpp @@ -22,14 +22,14 @@ limitations under the License. 
namespace oneflow { namespace one { -struct SliceOpExprInterpState : public OpExprInterpState { +struct SliceCaptureState : public AutoGradCaptureState { bool requires_grad; std::vector start; std::vector stop; std::vector step; }; -class Slice : public OpExprGradFunction { +class Slice : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -38,8 +38,8 @@ class Slice : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(SliceOpExprInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrMap& attrs) const override { + Maybe Capture(SliceCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(outputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); @@ -53,7 +53,7 @@ class Slice : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const SliceOpExprInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const SliceCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& like = ctx->SavedTensors().at(0); @@ -67,7 +67,7 @@ class Slice : public OpExprGradFunction { AttrMap base_attrs_; }; -struct SliceUpdateOpExprInterpState : public OpExprInterpState { +struct SliceUpdateCaptureState : public AutoGradCaptureState { bool requires_grad_x; bool requires_grad_update; std::vector start; @@ -75,7 +75,7 @@ struct SliceUpdateOpExprInterpState : public OpExprInterpState { std::vector step; }; -class SliceUpdate : public OpExprGradFunction { +class SliceUpdate : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -85,7 +85,7 @@ class SliceUpdate : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(SliceUpdateOpExprInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(SliceUpdateCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); CHECK_EQ_OR_RETURN(outputs.size(), 1); @@ -102,7 +102,7 @@ class SliceUpdate : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const SliceUpdateOpExprInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const SliceUpdateCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { in_grads->resize(2); diff --git a/oneflow/core/autograd/gradient_funcs/smoothl1_loss.cpp b/oneflow/core/autograd/gradient_funcs/smoothl1_loss.cpp index 05db4f22012..fcf6aa3e801 100644 --- a/oneflow/core/autograd/gradient_funcs/smoothl1_loss.cpp +++ b/oneflow/core/autograd/gradient_funcs/smoothl1_loss.cpp @@ -21,7 +21,7 @@ limitations under the License. 
namespace oneflow { namespace one { -struct SmoothL1LossInterpState : public OpExprInterpState { +struct SmoothL1LossCaptureState : public AutoGradCaptureState { std::string reduction; float beta; size_t prediction_index; @@ -29,7 +29,7 @@ struct SmoothL1LossInterpState : public OpExprInterpState { bool requires_grad; }; -class SmoothL1Loss : public OpExprGradFunction { +class SmoothL1Loss : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -38,7 +38,7 @@ class SmoothL1Loss : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(SmoothL1LossInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(SmoothL1LossCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->requires_grad = inputs.at(0)->requires_grad(); // prediction @@ -52,7 +52,7 @@ class SmoothL1Loss : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Apply(const SmoothL1LossInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const SmoothL1LossCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { CHECK_EQ_OR_RETURN(out_grads.size(), 1); in_grads->resize(2); diff --git a/oneflow/core/autograd/gradient_funcs/softmax.cpp b/oneflow/core/autograd/gradient_funcs/softmax.cpp index 10799b045c9..f8dfeac04eb 100644 --- a/oneflow/core/autograd/gradient_funcs/softmax.cpp +++ b/oneflow/core/autograd/gradient_funcs/softmax.cpp @@ -22,16 +22,16 @@ limitations under the License. namespace oneflow { namespace one { -struct SoftmaxInterpState : public OpExprInterpState { +struct SoftmaxCaptureState : public AutoGradCaptureState { bool requires_grad; }; -class Softmax : public OpExprGradFunction { +class Softmax : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(SoftmaxInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const SoftmaxInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -46,7 +46,7 @@ Maybe Softmax::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe Softmax::Capture(SoftmaxInterpState* ctx, const TensorTuple& inputs, +Maybe Softmax::Capture(SoftmaxCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 1); ctx->requires_grad = inputs.at(0)->requires_grad(); @@ -57,7 +57,7 @@ Maybe Softmax::Capture(SoftmaxInterpState* ctx, const TensorTuple& inputs, return Maybe::Ok(); } -Maybe Softmax::Apply(const SoftmaxInterpState* ctx, const TensorTuple& out_grads, +Maybe Softmax::Apply(const SoftmaxCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { if (!ctx->requires_grad) return Maybe::Ok(); CHECK_EQ_OR_RETURN(out_grads.size(), 1); diff --git a/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp index cb1c93c09f3..3e1dea3df22 100644 --- a/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp +++ b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp @@ -23,16 +23,16 @@ limitations under the License. 
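// --------------------------------------------------------------------------
// LogSoftmax earlier and SparseSoftmaxCrossEntropy just below wrap forward
// ops with two outputs, so Apply receives one out_grad per forward output and
// checks that explicitly; typically only one slot feeds the backward
// computation. Fragment sketch (hypothetical names):
Maybe<void> Apply(const TwoOutputCaptureState* ctx, const TensorTuple& out_grads,
                  TensorTuple* in_grads) const override {
  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // one dy per forward output
  const auto& dy = out_grads.at(0);         // which slot is used is op-specific
  in_grads->resize(1);
  // ... combine dy with the tensors saved in Capture ...
  return Maybe<void>::Ok();
}
// --------------------------------------------------------------------------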
namespace oneflow { namespace one { -struct SparseSoftmaxCrossEntropyInterpState : public OpExprInterpState { +struct SparseSoftmaxCrossEntropyCaptureState : public AutoGradCaptureState { int64_t depth; }; -class SparseSoftmaxCrossEntropy : public OpExprGradFunction { +class SparseSoftmaxCrossEntropy : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -50,7 +50,7 @@ Maybe SparseSoftmaxCrossEntropy::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyInterpState* ctx, +Maybe SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { @@ -63,7 +63,7 @@ Maybe SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyInterpSt return Maybe::Ok(); } -Maybe SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyInterpState* ctx, +Maybe SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { CHECK_EQ_OR_RETURN(out_grads.size(), 2); diff --git a/oneflow/core/autograd/gradient_funcs/split_like.cpp b/oneflow/core/autograd/gradient_funcs/split_like.cpp index 7d03cd5f20d..4424263384e 100644 --- a/oneflow/core/autograd/gradient_funcs/split_like.cpp +++ b/oneflow/core/autograd/gradient_funcs/split_like.cpp @@ -23,17 +23,17 @@ limitations under the License. 
namespace oneflow { namespace one { -struct SplitLikeInterpState : public OpExprInterpState { +struct SplitLikeCaptureState : public AutoGradCaptureState { int64_t max_dim_size; bool requires_grad; }; -class SplitLike : public OpExprGradFunction { +class SplitLike : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(SplitLikeInterpState* ctx, const TensorTuple& inputs, + Maybe Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const SplitLikeInterpState* ctx, const TensorTuple& out_grads, + Maybe Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: @@ -60,7 +60,7 @@ Maybe SplitLike::Init(const OpExpr& op) { return Maybe::Ok(); } -Maybe SplitLike::Capture(SplitLikeInterpState* ctx, const TensorTuple& inputs, +Maybe SplitLike::Capture(SplitLikeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), outputs.size() + 1); ctx->requires_grad = inputs.at(0)->requires_grad(); @@ -73,7 +73,7 @@ Maybe SplitLike::Capture(SplitLikeInterpState* ctx, const TensorTuple& inp return Maybe::Ok(); } -Maybe SplitLike::Apply(const SplitLikeInterpState* ctx, const TensorTuple& out_grads, +Maybe SplitLike::Apply(const SplitLikeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const { in_grads->resize(1); if (!ctx->requires_grad) { return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/squeeze.cpp b/oneflow/core/autograd/gradient_funcs/squeeze.cpp index a69a600394f..8d1e56ba5a3 100644 --- a/oneflow/core/autograd/gradient_funcs/squeeze.cpp +++ b/oneflow/core/autograd/gradient_funcs/squeeze.cpp @@ -22,16 +22,16 @@ limitations under the License. 
 namespace oneflow {
 namespace one {
 
-struct SqueezeInterpState : public OpExprInterpState {
+struct SqueezeCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class Squeeze : public OpExprGradFunction<SqueezeInterpState> {
+class Squeeze : public OpExprGradFunction<SqueezeCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(SqueezeInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const SqueezeInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -48,7 +48,7 @@ Maybe<void> Squeeze::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Squeeze::Capture(SqueezeInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Squeeze::Capture(SqueezeCaptureState* ctx, const TensorTuple& inputs,
                              const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -57,7 +57,7 @@ Maybe<void> Squeeze::Capture(SqueezeInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Squeeze::Apply(const SqueezeInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Squeeze::Apply(const SqueezeCaptureState* ctx, const TensorTuple& out_grads,
                            TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp b/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp
index fa3a6264aff..f95225211a4 100644
--- a/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp
+++ b/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp
@@ -22,18 +22,18 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct TensorScalarInterpState : public OpExprInterpState {
+struct TensorScalarCaptureState : public AutoGradCaptureState {
   bool x_requires_grad;
   bool scalar_requires_grad;
 };
 
-class TensorScalarAddOrSub : public OpExprGradFunction<TensorScalarInterpState> {
+class TensorScalarAddOrSub : public OpExprGradFunction<TensorScalarCaptureState> {
  public:
   TensorScalarAddOrSub() = default;
   virtual ~TensorScalarAddOrSub() = default;
 
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
 
  protected:
@@ -55,7 +55,7 @@ Maybe<void> TensorScalarAddOrSub::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> TensorScalarAddOrSub::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> TensorScalarAddOrSub::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
                                           const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->x_requires_grad = inputs.at(0)->requires_grad();
   ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
@@ -64,7 +64,7 @@ class TensorScalarAdd : public TensorScalarAddOrSub {
  public:
-  Maybe<void> Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     in_grads->resize(2);
     if (ctx->x_requires_grad) {
@@ -85,7 +85,7 @@ class TensorScalarAdd : public TensorScalarAddOrSub {
 
 class TensorScalarSub : public TensorScalarAddOrSub {
  public:
-  Maybe<void> Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     in_grads->resize(2);
     if (ctx->x_requires_grad) {
@@ -108,12 +108,12 @@ class TensorScalarSub : public TensorScalarAddOrSub {
 REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_add_by_tensor", TensorScalarAdd);
 REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_sub_by_tensor", TensorScalarSub);
 
-class TensorScalarMul : public OpExprGradFunction<TensorScalarInterpState> {
+class TensorScalarMul : public OpExprGradFunction<TensorScalarCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -133,7 +133,7 @@ Maybe<void> TensorScalarMul::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> TensorScalarMul::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> TensorScalarMul::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
                                      const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->x_requires_grad = inputs.at(0)->requires_grad();
   ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
@@ -142,8 +142,8 @@ Maybe<void> TensorScalarMul::Capture(TensorScalarInterpState* ctx, const TensorT
   return Maybe<void>::Ok();
 }
 
-Maybe<void> TensorScalarMul::Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
-                                   TensorTuple* in_grads) const {
+Maybe<void> TensorScalarMul::Apply(const TensorScalarCaptureState* ctx,
+                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {
   in_grads->resize(2);
   if (ctx->x_requires_grad) {
     const auto& scalar = ctx->SavedTensors().at(0);
@@ -165,12 +165,12 @@ Maybe<void> TensorScalarMul::Apply(const TensorScalarInterpState* ctx, const Ten
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("scalar_mul_by_tensor", TensorScalarMul);
 
-class TensorScalarDiv : public OpExprGradFunction<TensorScalarInterpState> {
+class TensorScalarDiv : public OpExprGradFunction<TensorScalarCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const TensorScalarCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -188,7 +188,7 @@ Maybe<void> TensorScalarDiv::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> TensorScalarDiv::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> TensorScalarDiv::Capture(TensorScalarCaptureState* ctx, const TensorTuple& inputs,
                                      const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->x_requires_grad = inputs.at(0)->requires_grad();
   ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
@@ -199,8 +199,8 @@ Maybe<void> TensorScalarDiv::Capture(TensorScalarInterpState* ctx, const TensorT
   return Maybe<void>::Ok();
 }
 
-Maybe<void> TensorScalarDiv::Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
-                                   TensorTuple* in_grads) const {
+Maybe<void> TensorScalarDiv::Apply(const TensorScalarCaptureState* ctx,
+                                   const TensorTuple& out_grads, TensorTuple* in_grads) const {
   in_grads->resize(2);
   if (ctx->x_requires_grad) {
     const auto& scalar = ctx->SavedTensors().at(0);
diff --git a/oneflow/core/autograd/gradient_funcs/transpose.cpp b/oneflow/core/autograd/gradient_funcs/transpose.cpp
index 41e3db11238..570132edbcb 100644
--- a/oneflow/core/autograd/gradient_funcs/transpose.cpp
+++ b/oneflow/core/autograd/gradient_funcs/transpose.cpp
@@ -22,17 +22,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct TransposeInterpState : public OpExprInterpState {
+struct TransposeCaptureState : public AutoGradCaptureState {
   std::vector<int32_t> perm;
   bool requires_grad;
 };
 
-class Transpose : public OpExprGradFunction<TransposeInterpState> {
+class Transpose : public OpExprGradFunction<TransposeCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(TransposeInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const TransposeInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -50,7 +50,7 @@ Maybe<void> Transpose::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Transpose::Capture(TransposeInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Transpose::Capture(TransposeCaptureState* ctx, const TensorTuple& inputs,
                                const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -60,7 +60,7 @@ Maybe<void> Transpose::Capture(TransposeInterpState* ctx, const TensorTuple& inp
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Transpose::Apply(const TransposeInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Transpose::Apply(const TransposeCaptureState* ctx, const TensorTuple& out_grads,
                              TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/triu.cpp b/oneflow/core/autograd/gradient_funcs/triu.cpp
index 0bcce6ac7da..ed04de8074d 100644
--- a/oneflow/core/autograd/gradient_funcs/triu.cpp
+++ b/oneflow/core/autograd/gradient_funcs/triu.cpp
@@ -20,17 +20,17 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct TriuInterpState : public OpExprInterpState {
+struct TriuCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   int64_t diagonal;
 };
 
-class Triu : public OpExprGradFunction<TriuInterpState> {
+class Triu : public OpExprGradFunction<TriuCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(TriuInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(TriuCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override;
-  Maybe<void> Apply(const TriuInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -44,7 +44,7 @@ Maybe<void> Triu::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Triu::Capture(TriuInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Triu::Capture(TriuCaptureState* ctx, const TensorTuple& inputs,
                           const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -53,7 +53,7 @@ Maybe<void> Triu::Capture(TriuInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Triu::Apply(const TriuInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Triu::Apply(const TriuCaptureState* ctx, const TensorTuple& out_grads,
                         TensorTuple* in_grads) const {
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
   in_grads->resize(1);
diff --git a/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp b/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp
index 3faf92e3d5c..7246b3c1d5e 100644
--- a/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp
+++ b/oneflow/core/autograd/gradient_funcs/unsqueeze.cpp
@@ -22,16 +22,16 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct UnsqueezeInterpState : public OpExprInterpState {
+struct UnsqueezeCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class Unsqueeze : public OpExprGradFunction<UnsqueezeInterpState> {
+class Unsqueeze : public OpExprGradFunction<UnsqueezeCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(UnsqueezeInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const UnsqueezeInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -48,7 +48,7 @@ Maybe<void> Unsqueeze::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Unsqueeze::Capture(UnsqueezeInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Unsqueeze::Capture(UnsqueezeCaptureState* ctx, const TensorTuple& inputs,
                                const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -57,7 +57,7 @@ Maybe<void> Unsqueeze::Capture(UnsqueezeInterpState* ctx, const TensorTuple& inp
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Unsqueeze::Apply(const UnsqueezeInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Unsqueeze::Apply(const UnsqueezeCaptureState* ctx, const TensorTuple& out_grads,
                              TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/upsample.cpp b/oneflow/core/autograd/gradient_funcs/upsample.cpp
index 51ab7d58937..c6eeaca4f78 100644
--- a/oneflow/core/autograd/gradient_funcs/upsample.cpp
+++ b/oneflow/core/autograd/gradient_funcs/upsample.cpp
@@ -23,7 +23,7 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct UpsampleInterpState : public OpExprInterpState {
+struct UpsampleCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   float height_scale;
   float width_scale;
@@ -32,12 +32,12 @@ struct UpsampleInterpState : public OpExprInterpState {
   std::string interpolation;
 };
 
-class Upsample : public OpExprGradFunction<UpsampleInterpState> {
+class Upsample : public OpExprGradFunction<UpsampleCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(UpsampleInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override;
-  Maybe<void> Apply(const UpsampleInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
@@ -61,7 +61,7 @@ Maybe<void> Upsample::Init(const OpExpr& op) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Upsample::Capture(UpsampleInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Upsample::Capture(UpsampleCaptureState* ctx, const TensorTuple& inputs,
                               const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad = inputs.at(0)->requires_grad();
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -75,7 +75,7 @@ Maybe<void> Upsample::Capture(UpsampleInterpState* ctx, const TensorTuple& input
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Upsample::Apply(const UpsampleInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Upsample::Apply(const UpsampleCaptureState* ctx, const TensorTuple& out_grads,
                             TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -94,18 +94,18 @@ Maybe<void> Upsample::Apply(const UpsampleInterpState* ctx, const TensorTuple& o
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample);
 
-struct UpsampleNearest2DInterpState : public OpExprInterpState {
+struct UpsampleNearest2DCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   float height_scale;
   float width_scale;
   std::string data_format;
 };
 
-class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DInterpState> {
+class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(UpsampleNearest2DInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleNearest2DCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -119,7 +119,7 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DInterpStat
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const UpsampleNearest2DInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleNearest2DCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -138,7 +138,7 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DInterpStat
-class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DInterpState> {
+class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(UpsampleBilinear2DInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleBilinear2DCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -165,7 +165,7 @@ class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DInterpSt
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const UpsampleBilinear2DInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleBilinear2DCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -185,18 +185,18 @@ class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DInterpSt
-class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DInterpState> {
+class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(UpsampleLinear1DInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleLinear1DCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -210,7 +210,7 @@ class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DInterpState>
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const UpsampleLinear1DInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleLinear1DCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -229,17 +229,17 @@ class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DInterpState>
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D);
 
-struct UpsampleNearest1DInterpState : public OpExprInterpState {
+struct UpsampleNearest1DCaptureState : public AutoGradCaptureState {
   bool requires_grad;
   float scale_factor;
   std::string data_format;
 };
 
-class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DInterpState> {
+class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(UpsampleNearest1DInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleNearest1DCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -252,7 +252,7 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DInterpStat
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const UpsampleNearest1DInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleNearest1DCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -271,7 +271,7 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DInterpStat
-class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DInterpState> {
+class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(UpsampleBicubic2DInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleBicubic2DCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -298,7 +298,7 @@ class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DInterpStat
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const UpsampleBicubic2DInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleBicubic2DCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -317,7 +317,7 @@ class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DInterpStat
-class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DInterpState> {
+class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(UpsampleNearest3DInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleNearest3DCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -344,7 +344,7 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DInterpStat
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const UpsampleNearest3DInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleNearest3DCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -364,7 +364,7 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DInterpStat
-class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DInterpState> {
+class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(UpsampleTrilinear3DInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -393,7 +393,7 @@ class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DInterp
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const UpsampleTrilinear3DInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const UpsampleTrilinear3DCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/autograd/gradient_funcs/where.cpp b/oneflow/core/autograd/gradient_funcs/where.cpp
index f31b340d58d..2e2f0f8e0a9 100644
--- a/oneflow/core/autograd/gradient_funcs/where.cpp
+++ b/oneflow/core/autograd/gradient_funcs/where.cpp
@@ -21,27 +21,27 @@ limitations under the License.
 
 namespace oneflow {
 namespace one {
 
-struct WhereInterpState : public OpExprInterpState {
+struct WhereCaptureState : public AutoGradCaptureState {
   bool requires_grad_x;
   bool requires_grad_y;
 };
 
-struct WhereScalarInterpState : public OpExprInterpState {
+struct WhereScalarCaptureState : public AutoGradCaptureState {
   bool requires_grad;
 };
 
-class Where : public OpExprGradFunction<WhereInterpState> {
+class Where : public OpExprGradFunction<WhereCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(WhereInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+  Maybe<void> Capture(WhereCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override;
-  Maybe<void> Apply(const WhereInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 };
 
 Maybe<void> Where::Init(const OpExpr& op) { return Maybe<void>::Ok(); }
 
-Maybe<void> Where::Capture(WhereInterpState* ctx, const TensorTuple& inputs,
+Maybe<void> Where::Capture(WhereCaptureState* ctx, const TensorTuple& inputs,
                            const TensorTuple& outputs, const AttrMap& attrs) const {
   ctx->requires_grad_x = inputs.at(1)->requires_grad();
   ctx->requires_grad_y = inputs.at(2)->requires_grad();
@@ -53,7 +53,7 @@ Maybe<void> Where::Capture(WhereInterpState* ctx, const TensorTuple& inputs,
   return Maybe<void>::Ok();
 }
 
-Maybe<void> Where::Apply(const WhereInterpState* ctx, const TensorTuple& out_grads,
+Maybe<void> Where::Apply(const WhereCaptureState* ctx, const TensorTuple& out_grads,
                          TensorTuple* in_grads) const {
   if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe<void>::Ok(); }
   CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -74,10 +74,10 @@ Maybe<void> Where::Apply(const WhereInterpState* ctx, const TensorTuple& out_gra
   return Maybe<void>::Ok();
 }
 
-class WhereScalar : public OpExprGradFunction<WhereScalarInterpState> {
+class WhereScalar : public OpExprGradFunction<WhereScalarCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(WhereScalarInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(WhereScalarCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     ctx->requires_grad = inputs.at(1)->requires_grad();
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -90,7 +90,7 @@ class WhereScalar : public OpExprGradFunction<WhereScalarInterpState> {
 
 class WhereScalarX : public WhereScalar {
  public:
-  Maybe<void> Apply(const WhereScalarInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const WhereScalarCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
@@ -107,7 +107,7 @@ class WhereScalarX : public WhereScalar {
 
 class WhereScalarY : public WhereScalar {
  public:
-  Maybe<void> Apply(const WhereScalarInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const WhereScalarCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
diff --git a/oneflow/core/framework/op_expr.h b/oneflow/core/framework/op_expr.h
index 62a982b92f6..8e7d139beda 100644
--- a/oneflow/core/framework/op_expr.h
+++ b/oneflow/core/framework/op_expr.h
@@ -258,11 +258,11 @@ class SelectFirstOpExpr final : public OpExpr {
   mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;
 };
 
-class OpExprInterpState;
+class AutoGradCaptureState;
 // TODO(): Finish the class definition of `FunctionOpExpr`.
 class FunctionOpExpr : public OpExpr {
  public:
-  using FType = std::function<Maybe<void>(const std::shared_ptr<OpExprInterpState>& /*ctx*/,
+  using FType = std::function<Maybe<void>(const std::shared_ptr<AutoGradCaptureState>& /*ctx*/,
                                           const TensorTuple& /*inputs or out_grads*/,
                                           TensorTuple* /*outputs or in_grads*/)>;
 
@@ -287,8 +287,8 @@ class FunctionOpExpr : public OpExpr {
   FType forward() const { return forward_; }
   FType backward() const { return backward_; }
 
-  std::shared_ptr<const OpExprInterpState> state() const { return state_; }
-  std::shared_ptr<OpExprInterpState> mutable_state() { return state_; }
+  std::shared_ptr<const AutoGradCaptureState> state() const { return state_; }
+  std::shared_ptr<AutoGradCaptureState> mutable_state() { return state_; }
 
   Maybe<bool> IsGradDisabled() const override { return false; }
 
   Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override { OF_UNIMPLEMENTED(); }
 
@@ -296,7 +296,7 @@ class FunctionOpExpr : public OpExpr {
  private:
   FType forward_;
   FType backward_;
-  std::shared_ptr<OpExprInterpState> state_;
+  std::shared_ptr<AutoGradCaptureState> state_;
 };
 
 }  // namespace one
diff --git a/oneflow/core/framework/op_expr_grad_function.h b/oneflow/core/framework/op_expr_grad_function.h
index 70caf033641..b2e381f8409 100644
--- a/oneflow/core/framework/op_expr_grad_function.h
+++ b/oneflow/core/framework/op_expr_grad_function.h
@@ -18,40 +18,57 @@ limitations under the License.
 #define ONEFLOW_CORE_FRAMEWORK_OP_EXPR_GRAD_FUNCTION_H_
 
 #include "oneflow/core/common/auto_registration_factory.h"
-#include "oneflow/core/framework/op_interpreter.h"  // OpExprInterpState
+#include "oneflow/core/framework/op_interpreter.h"
 
 namespace oneflow {
 namespace one {
 
 static constexpr char kGradientOpSuffix[] = ".grad";
 
+class AutoGradCaptureState {
+ public:
+  AutoGradCaptureState() = default;
+  virtual ~AutoGradCaptureState() = default;
+
+  const TensorTuple& SavedTensors() const { return saved_tensors_; }
+
+  size_t SaveTensorForBackward(const std::shared_ptr<Tensor>& tensor) {
+    size_t offset = saved_tensors_.size();
+    saved_tensors_.push_back(tensor);
+    return offset;
+  }
+
+ private:
+  TensorTuple saved_tensors_;
+};
+
 // Stateless container base of the backward op exprs.
 // The backward op exprs should be contained in the derived class.
 class OpExprGradFunctionIf {
  public:
   virtual ~OpExprGradFunctionIf() = default;
 
-  virtual std::shared_ptr<OpExprInterpState> MakeCustomState() const = 0;
+  virtual std::shared_ptr<AutoGradCaptureState> MakeCustomState() const = 0;
 
   virtual Maybe<void> Init(const OpExpr& op) = 0;
 
   // Capture forward inputs and outputs for backward.
-  virtual Maybe<void> CaptureIf(OpExprInterpState* ctx, const TensorTuple& inputs,
+  virtual Maybe<void> CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs,
                                 const TensorTuple& outputs,
                                 const OpExprInterpContext& interp_ctx) const = 0;
 
-  virtual Maybe<void> ApplyIf(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+  virtual Maybe<void> ApplyIf(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                               TensorTuple* in_grads) const = 0;
 };
 
 template <typename StateT>
 class OpExprGradFunction : public OpExprGradFunctionIf {
  public:
-  std::shared_ptr<OpExprInterpState> MakeCustomState() const override {
+  std::shared_ptr<AutoGradCaptureState> MakeCustomState() const override {
     return std::make_shared<StateT>();
   }
 
-  Maybe<void> CaptureIf(OpExprInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> CaptureIf(AutoGradCaptureState* ctx, const TensorTuple& inputs,
                         const TensorTuple& outputs,
                         const OpExprInterpContext& interp_ctx) const override {
     StateT* state = dynamic_cast<StateT*>(ctx);
@@ -71,7 +88,7 @@ class OpExprGradFunction : public OpExprGradFunctionIf {
     return Capture(state, detach_inputs, detach_outputs, interp_ctx);
   }
 
-  Maybe<void> ApplyIf(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> ApplyIf(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
                       TensorTuple* in_grads) const override {
     const StateT* state = dynamic_cast<const StateT*>(ctx);
     CHECK_NOTNULL_OR_RETURN(state);
@@ -104,7 +121,7 @@ class OpExprGradClosure {
   explicit OpExprGradClosure(const std::shared_ptr<OpExprGradFunctionIf>& impl)
       : impl_(impl), state_(impl->MakeCustomState()) {}
   explicit OpExprGradClosure(const std::shared_ptr<OpExprGradFunctionIf>& impl,
-                             const std::shared_ptr<OpExprInterpState>& state)
+                             const std::shared_ptr<AutoGradCaptureState>& state)
       : impl_(impl), state_(state) {}
 
   virtual ~OpExprGradClosure() = default;
@@ -120,7 +137,7 @@ class OpExprGradClosure {
 
  private:
   std::shared_ptr<OpExprGradFunctionIf> impl_;
-  std::shared_ptr<OpExprInterpState> state_;
+  std::shared_ptr<AutoGradCaptureState> state_;
 };
 
 #define REGISTER_OP_EXPR_GRAD_FUNCTION(op_type, op_grad) \
diff --git a/oneflow/core/framework/op_interpreter.h b/oneflow/core/framework/op_interpreter.h
index b44129a2b6b..71ea6b844d4 100644
--- a/oneflow/core/framework/op_interpreter.h
+++ b/oneflow/core/framework/op_interpreter.h
@@ -33,23 +33,6 @@ class NdSbp;
 
 namespace one {
 
-class OpExprInterpState {
- public:
-  OpExprInterpState() = default;
-  virtual ~OpExprInterpState() = default;
-
-  const TensorTuple& SavedTensors() const { return saved_tensors_; }
-
-  size_t SaveTensorForBackward(const std::shared_ptr<Tensor>& tensor) {
-    size_t offset = saved_tensors_.size();
-    saved_tensors_.push_back(tensor);
-    return offset;
-  }
-
- private:
-  TensorTuple saved_tensors_;
-};
-
 struct OpExprInterpContext {
   OpExprInterpContext(const AttrMap& attrs_arg) : attrs(attrs_arg) {}
   OpExprInterpContext(const AttrMap& attrs_arg, Symbol<Device> device_arg)
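
Reviewer note: the renamed AutoGradCaptureState keeps the exact interface of the removed OpExprInterpState (SavedTensors() / SaveTensorForBackward(), moved verbatim from op_interpreter.h into op_expr_grad_function.h), so gradient functions only have to update the type names they mention. For orientation, here is a minimal sketch of a gradient function written against the renamed class; FooCaptureState, Foo, and the "foo" op type are hypothetical names used for illustration and are not part of this diff:

struct FooCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class Foo : public OpExprGradFunction<FooCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(FooCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->requires_grad = inputs.at(0)->requires_grad();
    // Stash the forward input; SaveTensorForBackward() returns the offset at
    // which the tensor is stored, so Apply() can read it back from SavedTensors().
    ctx->SaveTensorForBackward(inputs.at(0));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const FooCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(1);
    const auto& saved_input = ctx->SavedTensors().at(0);  // captured in Capture()
    // ... compute in_grads->at(0) from out_grads.at(0) and saved_input ...
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("foo", Foo);

Capture() runs alongside the forward pass and records whatever backward will need; Apply() then consumes the saved tensors by the offsets that SaveTensorForBackward() returned. This is the same pattern TensorScalarMul and TensorScalarDiv follow above with ctx->SavedTensors().at(0).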