Skip to content

Commit

Permalink
[DoubleGrad PR #5] Enabled gradient computations for grad_tensors pas…
Browse files Browse the repository at this point in the history
…sed to paddle.grad() (PaddlePaddle#41198)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* fixed minor issues

* [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad()

* Fixed minor issue

* Fixed CI-Inference issue

* Fixed CI-inference issues
  • Loading branch information
jim19930609 committed Apr 2, 2022
1 parent 56f108f commit afadb8c
Show file tree
Hide file tree
Showing 14 changed files with 124 additions and 70 deletions.
10 changes: 7 additions & 3 deletions paddle/fluid/eager/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@ add_subdirectory(accumulation)
add_subdirectory(custom_operator)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(pylayer)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)
add_dependencies(grad_tensor_holder eager_final_state_codegen)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)
endif()

cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)

cc_library(autograd_meta SRCS autograd_meta.cc DEPS phi_api phi_tensor)
cc_library(utils SRCS utils.cc DEPS phi_api phi_tensor global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)

add_subdirectory(tests)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(tests)
endif()
1 change: 1 addition & 0 deletions paddle/fluid/eager/api/utils/hook_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
// Simply Copy impl() to grad_tensor
grad_tensor->set_impl(t.impl());
grad_tensor->set_autograd_meta(t.mutable_autograd_meta());
return *grad_tensor.get();
} else {
VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
Expand Down
17 changes: 6 additions & 11 deletions paddle/fluid/eager/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
continue;
}

// TODO(zhanlve): Copy and Modify GradNode if is_general_grad
GradNodeBase* grad_node = shared_grad_node.get();

// Prepare GradTensorHolder
Expand All @@ -486,16 +487,9 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// Feed given tensor if it's provided
VLOG(6) << "Fill grad input tensor " << i << "with give grad tensor";

if (grad_tensors[i].is_initialized()) {
// Deep copy
paddle::experimental::Tensor tmp_tensor;
tmp_tensor.copy_(grad_tensors[i], grad_tensors[i].inner_place(), false);
node_input_buffers_dict[grad_node]->add(input_info.first,
input_info.second, tmp_tensor);
} else {
node_input_buffers_dict[grad_node]->add(
input_info.first, input_info.second, grad_tensors[i]);
}
// Deep copy
node_input_buffers_dict[grad_node]->CopyValueFromTensor(
input_info.first, input_info.second, grad_tensors[i]);

} else {
VLOG(6) << "Fill grad input tensor " << i << " with 1.0";
Expand All @@ -504,7 +498,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// dims
// GradTensorHolder will initialize another tensor with same tensortype,
// datatype and dims but filled with 1.0
node_input_buffers_dict[grad_node]->add(
node_input_buffers_dict[grad_node]->CopyValueFromTensor(
input_info.first, input_info.second, tensor, true /*fill_one=true*/);
}

Expand Down Expand Up @@ -686,6 +680,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
}
}

if (!is_general_grad) return {};
return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}
Expand Down
118 changes: 79 additions & 39 deletions paddle/fluid/eager/grad_tensor_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
Expand All @@ -26,9 +27,9 @@ void GradTensorHolder::SetBufferSlotRankZeros(size_t slot_id, size_t rank) {
paddle::experimental::zeros_like(buffer_[slot_id][rank]);
}

void GradTensorHolder::add(size_t slot_id, size_t rank,
const paddle::experimental::Tensor& t,
bool fill_one) {
void GradTensorHolder::CopyValueFromTensor(
size_t slot_id, size_t rank, const paddle::experimental::Tensor& t,
bool fill_one) {
// TODO(jiabin): We need to deal with empty input_buffer with slot size not
// empty;
PADDLE_ENFORCE(slot_id < buffer_.size(),
Expand All @@ -50,44 +51,15 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
slot_id, buffer_[slot_id].size(), rank));
if (!fill_one) {
paddle::experimental::Tensor& buffer_tensor = buffer_[slot_id][rank];
// TODO(jiabin): Code below is ugly to divide which inner var we used,
// remove framework::Variable
// related code later.
// This if statement is trying to test neither phi::Tensor nor
// framework::Variable is initialized.
if ((!buffer_tensor.defined() || !buffer_tensor.initialized())) {
// Simply copy tensor->impl
buffer_tensor = t;
// Perform deep copy here
buffer_tensor.copy_(t, t.inner_place(), false);
buffer_tensor.set_autograd_meta(t.mutable_autograd_meta());

} else {
// Accumulation
PADDLE_ENFORCE_EQ(t.initialized(), true,
paddle::platform::errors::Fatal(
"We can only accumulate initialized tensor, but we "
"got tensor: %s is empty please check you network "
"and make sure it creates grads.",
t.name()));
if (t.is_dense_tensor()) {
if (buffer_tensor.is_dense_tensor()) {
paddle::imperative::TensorAdd<paddle::experimental::Tensor>(
t, &buffer_tensor);
} else {
// TODO(jiabin): Support Other TensorBase later
paddle::experimental::Tensor new_buffer(
std::make_shared<phi::DenseTensor>(), "tmp_accumulator");
paddle::imperative::SelectedRowsAddTensor(buffer_tensor, t,
&new_buffer);
buffer_tensor.set_impl(new_buffer.impl());
}
} else {
// TODO(jiabin): Support Other TensorBase later
if (buffer_tensor.is_dense_tensor()) {
paddle::imperative::SelectedRowsAddToTensor(t, &buffer_tensor);
} else {
buffer_tensor =
std::move(*paddle::imperative::SelectedRowsMerge<
paddle::experimental::Tensor>(t, buffer_tensor));
}
}
PADDLE_THROW(paddle::platform::errors::Fatal(
"Cannot copy grad_tensors' value to grad tensor holders,"
"input buffer has already been initialized."));
}
} else {
// Create new tensor->impl and fill it with 1.0
Expand All @@ -98,4 +70,72 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
}
}

// Accumulates the incoming tensor `t` into the buffer entry located at
// buffer_[slot_id][rank].
//
// Behavior, in order:
//   * `slot_id` must index into buffer_, otherwise a Fatal error is raised.
//   * If the addressed slot holds no entries, the call is a logged no-op.
//   * `rank` must index into the slot, otherwise a Fatal error is raised.
//   * If the buffer entry is not yet defined/initialized, it is assigned
//     from `t`.
//   * Otherwise `t` must be initialized and is summed into the entry, with
//     the routine chosen from whether each side reports is_dense_tensor().
void GradTensorHolder::add(size_t slot_id, size_t rank,
                           const paddle::experimental::Tensor& t) {
  // TODO(jiabin): We need to deal with empty input_buffer with slot size not
  // empty;
  PADDLE_ENFORCE(slot_id < buffer_.size(),
                 paddle::platform::errors::Fatal(
                     "Invalid slot_id for GradTensorHolder::add() "
                     "which exceeds size of buffer"));

  auto& slot = buffer_[slot_id];
  VLOG(6) << "Add Tensor for buffer_ slot: " << slot_id
          << ", size: " << slot.size();

  // Nothing to accumulate into when the slot has no entries.
  if (slot.empty()) {
    VLOG(6) << "Pass add Tensor for buffer_ slot: " << slot_id
            << " since its buffer_ is empty ";
    return;
  }

  PADDLE_ENFORCE(
      rank < slot.size(),
      paddle::platform::errors::Fatal(
          "Invalid rank for GradTensorHolder::add() which exceeds size "
          "of buffer slot %d, got slot size is: %d rank is: %d",
          slot_id, slot.size(), rank));

  paddle::experimental::Tensor& accumulated = slot[rank];

  // TODO(jiabin): The split on which inner variable is in use is ugly;
  // remove the framework::Variable related code later.
  // The entry counts as "holding a value" only when it is both defined and
  // initialized.
  const bool holds_value = accumulated.defined() && accumulated.initialized();
  if (!holds_value) {
    // First contribution: simply copy tensor->impl.
    accumulated = t;
    return;
  }

  // From here on we are accumulating onto an existing value.
  PADDLE_ENFORCE_EQ(t.initialized(), true,
                    paddle::platform::errors::Fatal(
                        "We can only accumulate initialized tensor, but we "
                        "got tensor: %s is empty please check you network "
                        "and make sure it creates grads.",
                        t.name()));

  const bool incoming_is_dense = t.is_dense_tensor();
  const bool buffer_is_dense = accumulated.is_dense_tensor();

  // TODO(jiabin): Support Other TensorBase later
  // TODO(zhanlve): Replace the SelectedRows* helpers below with
  // add_dygraph_function once it's supported
  if (incoming_is_dense && buffer_is_dense) {
    // dense + dense
    accumulated = add_final_state_dygraph_function(t, accumulated);
  } else if (incoming_is_dense) {
    // incoming dense, buffer not dense: sum into a fresh DenseTensor and
    // swap its impl into the buffer entry.
    paddle::experimental::Tensor new_buffer(
        std::make_shared<phi::DenseTensor>(), "tmp_accumulator");
    paddle::imperative::SelectedRowsAddTensor(accumulated, t, &new_buffer);
    accumulated.set_impl(new_buffer.impl());
  } else if (buffer_is_dense) {
    // incoming not dense, buffer dense
    paddle::imperative::SelectedRowsAddToTensor(t, &accumulated);
  } else {
    // neither side is dense
    accumulated = std::move(
        *paddle::imperative::SelectedRowsMerge<paddle::experimental::Tensor>(
            t, accumulated));
  }
}

} // namespace egr
6 changes: 4 additions & 2 deletions paddle/fluid/eager/grad_tensor_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ class GradTensorHolder {
GradTensorHolder& operator=(const GradTensorHolder& other) = default;

// Create new tensor and copy tensor->impl
void add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t,
bool fill_one = false);
void add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t);
void CopyValueFromTensor(size_t slot_id, size_t rank,
const paddle::experimental::Tensor& t,
bool fill_one = false);

const std::vector<paddle::experimental::Tensor>& operator[](
const size_t& pos) {
Expand Down
5 changes: 1 addition & 4 deletions paddle/fluid/eager/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
add_subdirectory(data_structure_tests)
add_subdirectory(task_tests)

if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(performance_tests)
endif()
add_subdirectory(performance_tests)
5 changes: 4 additions & 1 deletion paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
cc_test(test_egr_ds_eager_tensor SRCS eager_tensor_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_accumulation_node SRCS accumulation_node_test.cc DEPS ${eager_deps})
cc_test(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS ${eager_deps})

if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_test(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc DEPS ${eager_deps} ${generated_deps})
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/phi/core/kernel_registry.h"

PD_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);

// TODO(jiabin): remove nolint here!!!
using namespace egr; // NOLINT
Expand Down Expand Up @@ -77,11 +78,11 @@ TEST(GradTensorHolder, Interfaces) {

// add():
// fill one
grad_tensor_holder.add(0, 0, et0, true);
grad_tensor_holder.CopyValueFromTensor(0, 0, et0, true);

// accumulation
grad_tensor_holder.add(1, 0, et0, false);
grad_tensor_holder.add(1, 0, et1, false);
grad_tensor_holder.add(1, 0, et0);
grad_tensor_holder.add(1, 0, et1);

// Buffers()
const auto& buffers = grad_tensor_holder.Buffers();
Expand Down Expand Up @@ -141,8 +142,8 @@ TEST(GradTensorHolder, SelectedRowsMergeAdd) {
GradTensorHolder({slot_meta, slot_meta});

// accumulation
grad_tensor_holder.add(0, 0, t1, false);
grad_tensor_holder.add(0, 0, t2, false);
grad_tensor_holder.add(0, 0, t1);
grad_tensor_holder.add(0, 0, t2);

// Buffers()
const auto& buffers = grad_tensor_holder.Buffers();
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/eager/tests/task_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
cc_test(test_egr_task_tensor_utils SRCS tensor_utils_test.cc DEPS ${eager_deps})
cc_test(test_egr_task_eager_utils SRCS eager_utils_test.cc DEPS ${eager_deps})
cc_test(test_egr_task_forward_autograd SRCS forward_autograd_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_grad SRCS grad_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)

if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node)
cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node)
cc_test(test_egr_task_grad SRCS grad_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node)
cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node)
cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node)
cc_test(test_egr_task_hook_intermidiate SRCS hook_test_intermidiate.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} dygraph_node)
cc_test(test_egr_task_autocodegen SRCS generated_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps})
endif()
1 change: 1 addition & 0 deletions paddle/fluid/eager/tests/task_tests/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);

namespace egr {

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
#include "paddle/phi/core/kernel_registry.h"

PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT);
#endif

namespace egr {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/eager/tests/task_tests/grad_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);

namespace egr {

TEST(Grad, SingleNodeEmptyGrad) {
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ class PADDLE_API Tensor final {
* @return AbstractAutogradMeta*
*/
AbstractAutogradMeta* get_autograd_meta() const;
const std::shared_ptr<AbstractAutogradMeta>& mutable_autograd_meta() const;

/**
* @brief Set the autograd meta object
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/api/lib/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ AbstractAutogradMeta *Tensor::get_autograd_meta() const {
return autograd_meta_.get();
}

// Returns a const reference to the shared_ptr member holding this tensor's
// autograd meta (autograd_meta_), rather than the raw pointer that
// get_autograd_meta() hands out via autograd_meta_.get().
// NOTE(review): despite the "mutable_" prefix this is a const member function
// returning a const reference; presumably the name refers to callers being
// able to share/mutate the pointee — confirm the intended contract.
const std::shared_ptr<AbstractAutogradMeta> &Tensor::mutable_autograd_meta()
const {
return autograd_meta_;
}

void Tensor::set_autograd_meta(
std::shared_ptr<AbstractAutogradMeta> autograd_meta) {
autograd_meta_ = std::move(autograd_meta);
Expand Down

0 comments on commit afadb8c

Please sign in to comment.