rename OpExprInterpState to AutoGradCaptureState #5918

Merged
merged 14 commits on Aug 19, 2021
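
The change is a mechanical rename of the capture-state base struct; the gradient-function pattern itself is untouched. A minimal before/after sketch (the Foo names are placeholders for illustration, not code from this PR):

// Before this PR: per-op capture state derived from OpExprInterpState.
struct FooInterpState : public OpExprInterpState {
  bool requires_grad;
};

// After this PR: the same state derives from AutoGradCaptureState.
struct FooCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

// The template argument of the gradient function follows the rename.
class Foo : public OpExprGradFunction<FooCaptureState> { /* Init / Capture / Apply unchanged */ };
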
61 changes: 31 additions & 30 deletions oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -19,15 +19,15 @@ limitations under the License.
namespace oneflow {
namespace one {

struct BaseActivationInterpState : public OpExprInterpState {
struct BaseActivationCaptureState : public AutoGradCaptureState {
bool requires_grad;
};

class BaseActivation : public OpExprGradFunction<BaseActivationInterpState> {
class BaseActivation : public OpExprGradFunction<BaseActivationCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(BaseActivationInterpState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(BaseActivationCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -39,7 +39,7 @@ class BaseActivation : public OpExprGradFunction<BaseActivationInterpState> {

class Silu : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -53,7 +53,7 @@ class Silu : public BaseActivation {

class Mish : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -67,7 +67,7 @@ class Mish : public BaseActivation {

class Selu : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -81,7 +81,7 @@ class Selu : public BaseActivation {

class Softsign : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -95,7 +95,7 @@ class Softsign : public BaseActivation {

class GeLU : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -109,7 +109,7 @@ class GeLU : public BaseActivation {

class HardSigmoid : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -123,7 +123,7 @@ class HardSigmoid : public BaseActivation {

class HardSwish : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -136,15 +136,15 @@ class HardSwish : public BaseActivation {
};

// ===== Activation with parms ====
struct ReLUInterpState : public OpExprInterpState {
struct ReLUCaptureState : public AutoGradCaptureState {
bool requires_grad;
};

class ReLU : public OpExprGradFunction<ReLUInterpState> {
class ReLU : public OpExprGradFunction<ReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(ReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
Maybe<void> Capture(ReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -153,7 +153,7 @@ class ReLU : public OpExprGradFunction<ReLUInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Apply(const ReLUInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const ReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -165,12 +165,13 @@ class ReLU : public OpExprGradFunction<ReLUInterpState> {
}
};

struct LeakyReluInterpState : public OpExprInterpState {
// ===== Activation with parms ====
struct LeakyReluCaptureState : public AutoGradCaptureState {
bool requires_grad;
float alpha;
};

class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
class LeakyRelu : public OpExprGradFunction<LeakyReluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -179,7 +180,7 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Capture(LeakyReluInterpState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(LeakyReluCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -191,7 +192,7 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Apply(const LeakyReluInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const LeakyReluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -206,13 +207,13 @@ class LeakyRelu : public OpExprGradFunction<LeakyReluInterpState> {
AttrMap base_attrs_;
};

struct HardTanhInterpState : public OpExprInterpState {
struct HardTanhCaptureState : public AutoGradCaptureState {
bool requires_grad;
double min_val;
double max_val;
};

class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
class HardTanh : public OpExprGradFunction<HardTanhCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -221,7 +222,7 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Capture(HardTanhInterpState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(HardTanhCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -234,7 +235,7 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Apply(const HardTanhInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const HardTanhCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -250,12 +251,12 @@ class HardTanh : public OpExprGradFunction<HardTanhInterpState> {
AttrMap base_attrs_;
};

struct EluInterpState : public OpExprInterpState {
struct EluCaptureState : public AutoGradCaptureState {
bool requires_grad;
double alpha;
};

class Elu : public OpExprGradFunction<EluInterpState> {
class Elu : public OpExprGradFunction<EluCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -264,7 +265,7 @@ class Elu : public OpExprGradFunction<EluInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Capture(EluInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
Maybe<void> Capture(EluCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
@@ -276,7 +277,7 @@ class Elu : public OpExprGradFunction<EluInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Apply(const EluInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const EluCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
@@ -291,16 +292,16 @@ class Elu : public OpExprGradFunction<EluInterpState> {
AttrMap base_attrs_;
};

struct PReLUInterpState : public OpExprInterpState {
struct PReLUCaptureState : public AutoGradCaptureState {
bool input_requires_grad;
bool alpha_requires_grad;
};

class PReLU : public OpExprGradFunction<PReLUInterpState> {
class PReLU : public OpExprGradFunction<PReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(PReLUInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
Maybe<void> Capture(PReLUCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
ctx->input_requires_grad = inputs.at(0)->requires_grad(); // input
@@ -311,7 +312,7 @@ class PReLU : public OpExprGradFunction<PReLUInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Apply(const PReLUInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const PReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
const auto& dy = out_grads.at(0);
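
Every file below repeats the pattern visible in activation.cpp: a small state struct derived from AutoGradCaptureState, a Capture that records requires_grad (plus any attributes), and an Apply that reads the captured state. A condensed sketch with a hypothetical Foo op; the Apply body is an assumption, since the real bodies are collapsed in this view:

struct FooCaptureState : public AutoGradCaptureState {
  bool requires_grad;
};

class Foo : public OpExprGradFunction<FooCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  Maybe<void> Capture(FooCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    CHECK_EQ_OR_RETURN(outputs.size(), 1);
    ctx->requires_grad = inputs.at(0)->requires_grad();  // skip backward work when not needed
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const FooCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(1);
    // Assumed shape of the collapsed bodies: call the op's functional backward
    // only when the input required grad.
    return Maybe<void>::Ok();
  }
};
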
14 changes: 7 additions & 7 deletions oneflow/core/autograd/gradient_funcs/adaptive_pool.cpp
@@ -23,18 +23,18 @@ limitations under the License.
namespace oneflow {
namespace one {

struct AdaptivePoolInterpState : public OpExprInterpState {
struct AdaptivePoolCaptureState : public AutoGradCaptureState {
bool requires_grad;
};

class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolInterpState> {
class AdaptivePoolNdGrad : public OpExprGradFunction<AdaptivePoolCaptureState> {
public:
using OpExprGradFunction<AdaptivePoolInterpState>::Init;
using OpExprGradFunction<AdaptivePoolCaptureState>::Init;

Maybe<void> Init(const OpExpr& op, std::string mode, const int& ndims);
Maybe<void> Capture(AdaptivePoolInterpState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const AdaptivePoolInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const AdaptivePoolCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
@@ -52,7 +52,7 @@ Maybe<void> AdaptivePoolNdGrad::Init(const OpExpr& op, std::string mode, const i
return Maybe<void>::Ok();
}

Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolInterpState* ctx, const TensorTuple& inputs,
Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -61,7 +61,7 @@ Maybe<void> AdaptivePoolNdGrad::Capture(AdaptivePoolInterpState* ctx, const Tens
return Maybe<void>::Ok();
}

Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolInterpState* ctx,
Maybe<void> AdaptivePoolNdGrad::Apply(const AdaptivePoolCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
8 changes: 4 additions & 4 deletions oneflow/core/autograd/gradient_funcs/add_n.cpp
@@ -18,16 +18,16 @@ limitations under the License.
namespace oneflow {
namespace one {

struct AddNInterpState : public OpExprInterpState {
struct AddNCaptureState : public AutoGradCaptureState {
int32_t input_num;
std::vector<bool> requires_grad;
};

class AddN : public OpExprGradFunction<AddNInterpState> {
class AddN : public OpExprGradFunction<AddNCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(AddNInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
Maybe<void> Capture(AddNCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
ctx->input_num = inputs.size();
ctx->requires_grad.resize(inputs.size());
@@ -37,7 +37,7 @@ class AddN : public OpExprGradFunction<AddNInterpState> {
return Maybe<void>::Ok();
}

Maybe<void> Apply(const AddNInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const AddNCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(ctx->input_num);
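
AddN is the one case here that keeps a per-input flag vector rather than a single bool, because its backward fans the single output gradient out to every input that needs it. A hedged sketch of what the collapsed Apply loop presumably does (assumption, not text from this diff):

for (int32_t i = 0; i < ctx->input_num; ++i) {
  // d(sum of inputs)/d(input_i) is the identity, so each participating input
  // simply receives the output gradient.
  if (ctx->requires_grad.at(i)) { in_grads->at(i) = out_grads.at(0); }
}
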
12 changes: 6 additions & 6 deletions oneflow/core/autograd/gradient_funcs/avg_pooling.cpp
@@ -26,7 +26,7 @@ namespace one {

namespace {

struct AvgPoolingInterpState : public OpExprInterpState {
struct AvgPoolingCaptureState : public AutoGradCaptureState {
bool requires_grad;
size_t input_index;
size_t output_index;
@@ -40,13 +40,13 @@ struct AvgPoolingInterpState : public OpExprInterpState {
int64_t divisor_override;
};

class AvgPoolingNdGrad : public OpExprGradFunction<AvgPoolingInterpState> {
class AvgPoolingNdGrad : public OpExprGradFunction<AvgPoolingCaptureState> {
public:
virtual ~AvgPoolingNdGrad() = default;
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(AvgPoolingInterpState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(AvgPoolingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const AvgPoolingInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const AvgPoolingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
@@ -60,7 +60,7 @@ Maybe<void> AvgPoolingNdGrad::Init(const OpExpr& op) {
return Maybe<void>::Ok();
}

Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingInterpState* ctx, const TensorTuple& inputs,
Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -80,7 +80,7 @@ Maybe<void> AvgPoolingNdGrad::Capture(AvgPoolingInterpState* ctx, const TensorTu
return Maybe<void>::Ok();
}

Maybe<void> AvgPoolingNdGrad::Apply(const AvgPoolingInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> AvgPoolingNdGrad::Apply(const AvgPoolingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
12 changes: 6 additions & 6 deletions oneflow/core/autograd/gradient_funcs/batch_gather.cpp
@@ -22,17 +22,17 @@ limitations under the License.
namespace oneflow {
namespace one {

struct BatchGatherInterpState : public OpExprInterpState {
struct BatchGatherCaptureState : public AutoGradCaptureState {
int64_t num_segments;
bool requires_grad;
};

class BatchGather : public OpExprGradFunction<BatchGatherInterpState> {
class BatchGather : public OpExprGradFunction<BatchGatherCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
@@ -48,7 +48,7 @@ Maybe<void> BatchGather::Init(const OpExpr& op) {
return Maybe<void>::Ok();
}

Maybe<void> BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
Maybe<void> BatchGather::Capture(BatchGatherCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -59,7 +59,7 @@ Maybe<void> BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple&
return Maybe<void>::Ok();
}

Maybe<void> BatchGather::Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads,
Maybe<void> BatchGather::Apply(const BatchGatherCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
in_grads->resize(2);
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
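
The registration lines sit outside the shown hunks. In this directory each gradient function is bound to its op type name with a registration macro; a hedged example, assuming the REGISTER_OP_EXPR_GRAD_FUNCTION macro used elsewhere in gradient_funcs and an illustrative op name string:

// Assumed pattern, not part of the visible diff: binds the grad function to an op type name.
REGISTER_OP_EXPR_GRAD_FUNCTION("batch_gather", BatchGather);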