diff --git a/oneflow/api/python/autograd/autograd.cpp b/oneflow/api/python/autograd/autograd.cpp
index af6467b615e..89c7ec18128 100644
--- a/oneflow/api/python/autograd/autograd.cpp
+++ b/oneflow/api/python/autograd/autograd.cpp
@@ -71,7 +71,7 @@ Maybe<one::TensorTuple> Backward(const one::TensorTuple& outputs, const one::Ten
                                  bool retain_graph, bool create_graph) {
   if (create_graph) { retain_graph = true; }
   std::shared_ptr<one::TensorTuple> gradients = JUST(CheckAndInitOutGrads(outputs, out_grads));
-  JUST(one::GetThreadLocalAutogradEngine()->RunBackwardAndSaveGrads4LeafTensor(
+  JUST(one::GetThreadLocalAutogradEngine()->RunBackwardAndSaveGrads4LeafTensorIf(
       outputs, *gradients, retain_graph, create_graph));
   return std::make_shared<one::TensorTuple>(0);
 }
@@ -86,7 +86,7 @@ Maybe<one::TensorTuple> Grad(const one::TensorTuple& outputs, const one::TensorT
       [](const std::shared_ptr<one::Tensor>& tensor) { return tensor->requires_grad(); }))
       << "All input tensors `.requires_grad` should be true";
   std::shared_ptr<one::TensorTuple> gradients = JUST(CheckAndInitOutGrads(outputs, out_grads));
-  return one::GetThreadLocalAutogradEngine()->RunBackwardAndReturnInputsTensorGrad(
+  return one::GetThreadLocalAutogradEngine()->RunBackwardAndReturnInputsTensorGradIf(
       outputs, inputs, *gradients, retain_graph, create_graph);
 }
 
diff --git a/oneflow/core/autograd/autograd_engine.cpp b/oneflow/core/autograd/autograd_engine.cpp
index b897571c94b..f1a43ca6046 100644
--- a/oneflow/core/autograd/autograd_engine.cpp
+++ b/oneflow/core/autograd/autograd_engine.cpp
@@ -74,25 +74,25 @@ Maybe<void> CheckConsistentTensorsMeta(const TensorTuple& tensor_tuple) {
 
 }  // namespace
 
-Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
-                                                               const TensorTuple& out_grads,
-                                                               bool retain_graph,
-                                                               bool create_graph) {
+Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
+                                                                 const TensorTuple& out_grads,
+                                                                 bool retain_graph,
+                                                                 bool create_graph) {
   JUST(CheckConsistentTensorsMeta(outputs));
   JUST(CheckConsistentTensorsMeta(out_grads));
   DisableCheckConsistentTensorMetaScope disable_meta_check;
-  return RunBackwardAndSaveGrads4LeafTensorIf(outputs, out_grads, retain_graph, create_graph);
+  return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
 }
 
-Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGrad(
+Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
     const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
     bool retain_graph, bool create_graph) {
   JUST(CheckConsistentTensorsMeta(outputs));
   JUST(CheckConsistentTensorsMeta(inputs));
   JUST(CheckConsistentTensorsMeta(out_grads));
   DisableCheckConsistentTensorMetaScope disable_meta_check;
-  return RunBackwardAndReturnInputsTensorGradIf(outputs, inputs, out_grads, retain_graph,
-                                                create_graph);
+  return RunBackwardAndReturnInputsTensorGrad(outputs, inputs, out_grads, retain_graph,
+                                              create_graph);
 }
 
 StackFunctionNode::StackFunctionNode(
@@ -190,10 +190,10 @@ void StackAutogradEngine::ClearReleasedFunctionNodes() {
                    node_list_.end());
 }
 
-Maybe<void> StackAutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
-                                                                      const TensorTuple& out_grads,
-                                                                      bool retain_graph,
-                                                                      bool create_graph) {
+Maybe<void> StackAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
+                                                                    const TensorTuple& out_grads,
+                                                                    bool retain_graph,
+                                                                    bool create_graph) {
   ClearReleasedFunctionNodes();
   for (int i = 0; i < outputs.size(); ++i) {
     JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
@@ -213,7 +213,7 @@ Maybe<void> StackAutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const Tens
   return Maybe<void>::Ok();
 }
 
-Maybe<TensorTuple> StackAutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
+Maybe<TensorTuple> StackAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
     const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
     bool retain_graph, bool create_graph) {
   ClearReleasedFunctionNodes();
@@ -419,10 +419,10 @@ Maybe<void> GraphTask::Apply(bool save_grad_for_leaf) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
-                                                                      const TensorTuple& out_grads,
-                                                                      bool retain_graph,
-                                                                      bool create_graph) {
+Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
+                                                                    const TensorTuple& out_grads,
+                                                                    bool retain_graph,
+                                                                    bool create_graph) {
   for (int i = 0; i < outputs.size(); ++i) {
     JUST(JUST(outputs.at(i)->current_grad())->PushPartialTensor(out_grads.at(i)));
   }
@@ -432,7 +432,7 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const Tens
   return Maybe<void>::Ok();
 }
 
-Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
+Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
     const TensorTuple& outputs, const TensorTuple& inputs, const TensorTuple& out_grads,
     bool retain_graph, bool create_graph) {
   std::shared_ptr<TensorTuple> input_current_grad = std::make_shared<TensorTuple>(inputs.size());
diff --git a/oneflow/core/autograd/autograd_engine.h b/oneflow/core/autograd/autograd_engine.h
index 5fe1230f0e7..f3f05cafec0 100644
--- a/oneflow/core/autograd/autograd_engine.h
+++ b/oneflow/core/autograd/autograd_engine.h
@@ -69,13 +69,13 @@ class AutogradEngine {
  public:
   virtual ~AutogradEngine() = default;
 
-  Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
-                                                 const TensorTuple& out_grads, bool retain_graph,
-                                                 bool create_graph);
-  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
-                                                          const TensorTuple& inputs,
-                                                          const TensorTuple& out_grads,
-                                                          bool retain_graph, bool create_graph);
+  Maybe<void> RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
+                                                   const TensorTuple& out_grads, bool retain_graph,
+                                                   bool create_graph);
+  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,
+                                                            const TensorTuple& inputs,
+                                                            const TensorTuple& out_grads,
+                                                            bool retain_graph, bool create_graph);
   virtual void ClearEngine() = 0;
   // Builds FunctionNode, binding to all `outputs_` tensors and saving in AutogradEngine
   virtual Maybe<void> AddBackwardFuncPtr(
@@ -88,15 +88,14 @@ class AutogradEngine {
   AutogradEngine() = default;
 
  private:
-  virtual Maybe<void> RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
-                                                           const TensorTuple& out_grads,
-                                                           bool retain_graph,
-                                                           bool create_graph) = 0;
-  virtual Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,
-                                                                    const TensorTuple& inputs,
-                                                                    const TensorTuple& out_grads,
-                                                                    bool retain_graph,
-                                                                    bool create_graph) = 0;
+  virtual Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
+                                                         const TensorTuple& out_grads,
+                                                         bool retain_graph, bool create_graph) = 0;
+  virtual Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
+                                                                  const TensorTuple& inputs,
+                                                                  const TensorTuple& out_grads,
+                                                                  bool retain_graph,
+                                                                  bool create_graph) = 0;
 };
 
 // Stack Autograd Node and Engine
@@ -137,14 +136,14 @@ class StackAutogradEngine final : public AutogradEngine {
   // moment.
   std::list<std::weak_ptr<FunctionNode>> node_list_;
   void ClearReleasedFunctionNodes();
-  Maybe<void> RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
-                                                   const TensorTuple& out_grads, bool retain_graph,
-                                                   bool create_graph) override;
-  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,
-                                                            const TensorTuple& inputs,
-                                                            const TensorTuple& out_grads,
-                                                            bool retain_graph,
-                                                            bool create_graph) override;
+  Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
+                                                 const TensorTuple& out_grads, bool retain_graph,
+                                                 bool create_graph) override;
+  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
+                                                          const TensorTuple& inputs,
+                                                          const TensorTuple& out_grads,
+                                                          bool retain_graph,
+                                                          bool create_graph) override;
 };
 
 // Graph Autograd Node and Engine
@@ -194,14 +193,14 @@ class GraphAutogradEngine final : public AutogradEngine {
                                          const TensorTuple& inputs, TensorTuple* outputs) override;
 
  private:
-  Maybe<void> RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
-                                                   const TensorTuple& out_grads, bool retain_graph,
-                                                   bool create_graph) override;
-  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGradIf(const TensorTuple& outputs,
-                                                            const TensorTuple& inputs,
-                                                            const TensorTuple& out_grads,
-                                                            bool retain_graph,
-                                                            bool create_graph) override;
+  Maybe<void> RunBackwardAndSaveGrads4LeafTensor(const TensorTuple& outputs,
+                                                 const TensorTuple& out_grads, bool retain_graph,
+                                                 bool create_graph) override;
+  Maybe<TensorTuple> RunBackwardAndReturnInputsTensorGrad(const TensorTuple& outputs,
+                                                          const TensorTuple& inputs,
+                                                          const TensorTuple& out_grads,
+                                                          bool retain_graph,
+                                                          bool create_graph) override;
 };
 
 AutogradEngine* GetThreadLocalAutogradEngine();
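Note on the rename: this diff swaps the `If` suffix between the public wrappers and the private virtual hooks. After the change, callers go through the public, non-virtual `RunBackwardAndSaveGrads4LeafTensorIf` / `RunBackwardAndReturnInputsTensorGradIf`, which run the shared `CheckConsistentTensorsMeta` checks and open a `DisableCheckConsistentTensorMetaScope` before delegating to the plain-named private virtuals that `StackAutogradEngine` and `GraphAutogradEngine` override. The resulting shape is the non-virtual interface (NVI) idiom. Below is a minimal, self-contained C++ sketch of that shape; all names in it (`Engine`, `RunBackwardIf`, `RunBackward`, `CheckMeta`, `StackEngine`, `GraphEngine`) are hypothetical stand-ins, not OneFlow's actual API.

#include <iostream>
#include <memory>

// NVI sketch: the public "...If" entry point is non-virtual, performs the
// shared pre-checks, then dispatches to a private virtual hook that each
// engine overrides under the plain (no "If") name.
class Engine {
 public:
  virtual ~Engine() = default;

  // Every caller goes through the checks; subclasses cannot bypass them.
  void RunBackwardIf() {
    CheckMeta();    // stands in for CheckConsistentTensorsMeta(...)
    RunBackward();  // engine-specific backward implementation
  }

 private:
  void CheckMeta() { std::cout << "check tensors meta\n"; }
  virtual void RunBackward() = 0;  // overridden with the plain name
};

class StackEngine final : public Engine {
 private:
  void RunBackward() override { std::cout << "stack-based backward\n"; }
};

class GraphEngine final : public Engine {
 private:
  void RunBackward() override { std::cout << "graph-based backward\n"; }
};

int main() {
  std::unique_ptr<Engine> engine = std::make_unique<GraphEngine>();
  engine->RunBackwardIf();  // prints "check tensors meta", then "graph-based backward"
  return 0;
}

Because the virtual hooks are private, a derived engine can only customize the backward pass itself; the consistency checks in the `...If` wrappers always run, which is presumably why the checked entry points now carry the public-facing names.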