[DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul #41387

Merged

Changes from all commits (47 commits)
5e95fd1
[Refactor] refactored eager_gen.py PR #2
jim19930609 Mar 24, 2022
214ad36
[DoubleGrad PR #1] Decoupled code generation logics for Dygraph Forwa…
jim19930609 Mar 25, 2022
14ebc19
Fixed minor issue
jim19930609 Mar 25, 2022
62179ba
Merged develop
jim19930609 Mar 25, 2022
2a83dd9
Merged develop from [#40935] Bug fix for intermediate support in Yaml
jim19930609 Mar 28, 2022
9e7ecf6
Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefin…
jim19930609 Mar 28, 2022
b945424
Fixed issues
jim19930609 Mar 28, 2022
033913c
Merged develop
jim19930609 Mar 28, 2022
84ea03e
Supported higher-order grad node generation
jim19930609 Mar 28, 2022
f29de18
Merge branch 'refactored_node_creation_codes' of https://github.com/j…
jim19930609 Mar 28, 2022
6117fd1
[DoubleGrad PR #4] Supported higher-order GradNode generation
jim19930609 Mar 29, 2022
8785075
Merged develop
jim19930609 Mar 29, 2022
f647fb4
Merged develop
jim19930609 Mar 30, 2022
244f6da
[DoubleGrad #4] Bug Fixes to Double Grad Node Generation
jim19930609 Mar 30, 2022
d17aaef
Merged develop
jim19930609 Mar 30, 2022
4001488
Fixed yaml typo
jim19930609 Mar 30, 2022
564f04b
Merge branch 'support_double_grad_node_generation' of https://github.…
jim19930609 Mar 30, 2022
e26b572
Fixed yaml typo
jim19930609 Mar 30, 2022
abfef59
Merge branch 'support_double_grad_node_generation' of https://github.…
jim19930609 Mar 30, 2022
a16f769
merged develop
jim19930609 Mar 30, 2022
ca4fe62
merged develop
jim19930609 Mar 30, 2022
1b5fdb5
fixed minor issues
jim19930609 Mar 30, 2022
c542869
merged develop
jim19930609 Mar 31, 2022
a3ce1a9
[DoubleGrad PR #5] Enabled gradient computations for grad_tensors pas…
jim19930609 Mar 31, 2022
ec2abc2
Merge branch 'optimized_double_grad' of https://github.com/jim1993060…
jim19930609 Mar 31, 2022
1239f69
Fixed minor issue
jim19930609 Mar 31, 2022
57de577
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jim19930609 Mar 31, 2022
7e6b1e2
Merge branch 'optimized_double_grad' of https://github.com/jim1993060…
jim19930609 Apr 1, 2022
499b3aa
Fixed CI-Inference issue
jim19930609 Apr 1, 2022
dc1adae
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jim19930609 Apr 1, 2022
3359556
Fixed CI-inference issues
jim19930609 Apr 2, 2022
4c2ece1
[DoubleGrad PR #7] paddle.grad() to copy backward graph before backwa…
jim19930609 Apr 2, 2022
9b19fd8
Merge branch 'paddle_grad_bug_fix' of https://github.com/jim19930609/…
jim19930609 Apr 2, 2022
769f623
Fixed minor issues
jim19930609 Apr 2, 2022
ee13e8a
Fixed issue with backward graph construction logic
jim19930609 Apr 2, 2022
a86804b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jim19930609 Apr 2, 2022
a44667e
Fixed implementation issues with backward graph reconstruction
jim19930609 Apr 3, 2022
f1fbf49
Fixed unittest issue
jim19930609 Apr 3, 2022
0502d22
Fixed issues
jim19930609 Apr 3, 2022
ab17c93
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jim19930609 Apr 4, 2022
0da4325
[DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul
jim19930609 Apr 4, 2022
a4644c1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jim19930609 Apr 4, 2022
1cdf6a3
Merge branch 'paddle_grad_bug_fix_v2' of https://github.com/jim199306…
jim19930609 Apr 4, 2022
f7104b8
Fixed issues with phi kernel
jim19930609 Apr 4, 2022
7a8b254
Added triple grad test case
jim19930609 Apr 4, 2022
cdd3b53
merged develop
jim19930609 Apr 4, 2022
7f10421
Fixed minor issue
jim19930609 Apr 4, 2022
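Before the file-by-file diff, here is a minimal usage sketch of what this PR enables: third-order gradients for sigmoid (and matmul) under dygraph. The sketch is illustrative only and is not taken from the PR's test cases; it assumes the public paddle.grad API with create_graph=True, and the variable names are made up.

import paddle

x = paddle.to_tensor([0.5, 1.0, 2.0], stop_gradient=False)
y = paddle.nn.functional.sigmoid(x)

# First-order gradient, keeping the graph so higher orders can be taken.
(dx,) = paddle.grad(y, x, create_graph=True)
# Second-order gradient (double grad).
(ddx,) = paddle.grad(dx, x, create_graph=True)
# Third-order gradient (triple grad), the case enabled here for sigmoid and matmul.
(dddx,) = paddle.grad(ddx, x)
print(dddx)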
codegen_utils.py
@@ -21,8 +21,10 @@
########################
### Global Variables ###
########################
ops_to_fill_zero_for_empty_grads = set(
["split_grad", "rnn_grad", "matmul_double_grad"])
ops_to_fill_zero_for_empty_grads = set([
"split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad",
"sigmoid_triple_grad"
])

# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
@@ -171,12 +173,6 @@ def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"


def TransformGradVarNameForDoubleGradGeneration(string):
if IsGradName(string):
string = "grad_" + string[:-5]
return string


def GetIndent(num):
tab = " "
return "".join([tab for i in range(num)])
eager_gen.py
@@ -31,7 +31,6 @@
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import TransformGradVarNameForDoubleGradGeneration
from codegen_utils import AssertMessage, GetIndent


@@ -483,10 +482,8 @@ def ForwardsValidationCheck(self):
orig_forward_returns_list = self.orig_forward_returns_list

for i in range(len(forward_inputs_list)):
forward_input_name = forward_inputs_list[i][0]
forward_input_type = forward_inputs_list[i][1]
forward_input_pos = forward_inputs_list[i][2]
orig_input_name = orig_forward_inputs_list[i][0]
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]

@@ -496,11 +493,9 @@ def ForwardsValidationCheck(self):
forward_input_pos, orig_input_pos)

for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
orig_attr_type = orig_forward_attrs_list[i][1]
orig_attr_default = orig_forward_attrs_list[i][2]
orig_attr_pos = orig_forward_attrs_list[i][3]
forward_attr_name = forward_attrs_list[i][0]
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
@@ -1132,11 +1127,20 @@ def __init__(self,
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)

# Record name mapping from forward_api_name to grad_api_names
self.to_next_grad_name_mapping = {} # {name : name}

# Generated Results
self.node_declaration_str = ""
self.node_definition_str = ""
self.next_grad_api_contents = next_grad_api_contents

def TransformToNextGradName(self, string):
name_mapping = self.to_next_grad_name_mapping
if string in name_mapping.keys():
return name_mapping[string]
return string

def ResetOptionalInputs(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
@@ -1146,6 +1150,22 @@ def ResetOptionalInputs(self):

self.optional_inputs = base_generator.optional_inputs

def RecordGrad2NextGradNameMapping(self, next_node_generator):
next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
next_orig_returns_list = next_node_generator.orig_forward_returns_list

next_forward_inputs_list = next_node_generator.forward_inputs_list
next_forward_returns_list = next_node_generator.forward_returns_list
for i in range(len(next_orig_inputs_list)):
grad_name = next_orig_inputs_list[i][0]
next_forward_name = next_forward_inputs_list[i][0]
self.to_next_grad_name_mapping[grad_name] = next_forward_name

for i in range(len(next_orig_returns_list)):
grad_ret_name = next_orig_returns_list[i][0]
next_ret_name = next_forward_returns_list[i][0]
self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name

def GenerateHigherOrderNodeCreationCode(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
@@ -1163,6 +1183,8 @@ def GenerateHigherOrderNodeCreationCode(self):
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str

self.RecordGrad2NextGradNameMapping(next_node_generator)

return grad_node_creation_str

def GenerateNodeDeclaration(self):
@@ -1252,8 +1274,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

is_optional = (name in self.optional_inputs)
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
@@ -1273,8 +1294,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
# Grad Ins from grads
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

is_optional = (name in self.optional_inputs)
if IsPlainTensorType(ttype):
@@ -1315,8 +1335,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
num_outputs = len(backward_grad_outputs_map.keys())
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

if num_outputs == 1:
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
@@ -1338,8 +1357,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1357,8 +1375,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):

# 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1381,8 +1398,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
outputs_autograd_meta_list = []
num_fwd_outputs = len(backward_grad_outputs_map.keys())
for name, (rtype, pos, _) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

output_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1416,8 +1432,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

# Infer Grad API Return Type
if num_bwd_outputs == 1:
@@ -1440,6 +1455,9 @@ def GenerateNodeDefinition(self, grad_node_creation_str):

grad_node_name = GetGradNodeName(forward_api_name)

if len(grad_node_creation_str) == 0:
grad_node_creation_str = f"if(create_graph) VLOG(3) << \"Higher order grad node for {grad_node_name} has not been implemented yet.\";"

self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
@@ -1456,11 +1474,11 @@ def run(self):
#####################
## Code Generation ##
#####################
self.GenerateNodeDeclaration()

# Higher-order GradNode generation
grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()

self.GenerateNodeDeclaration()

self.GenerateNodeDefinition(grad_node_creation_str)


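As background for the eager_gen.py changes above, here is a self-contained sketch of the renaming scheme that replaces the removed TransformGradVarNameForDoubleGradGeneration helper: instead of rewriting "*_grad" names by string manipulation, the generator records a per-node dictionary pairing the grad op's original input/return names with the next-order forward names, and falls back to the original name when no entry exists. This sketch is not part of the diff, and the example names are hypothetical rather than taken from the real yaml.

# Hypothetical stand-in for RecordGrad2NextGradNameMapping / TransformToNextGradName.
to_next_grad_name_mapping = {}

def record_grad_to_next_grad_names(orig_names, next_forward_names):
    # Pair names positionally, mirroring how the generator walks the two lists.
    for orig_name, next_name in zip(orig_names, next_forward_names):
        to_next_grad_name_mapping[orig_name] = next_name

def transform_to_next_grad_name(name):
    # Unmapped names pass through unchanged, just like the original fallback.
    return to_next_grad_name_mapping.get(name, name)

record_grad_to_next_grad_names(["out_grad", "x"], ["grad_out", "x"])
assert transform_to_next_grad_name("out_grad") == "grad_out"
assert transform_to_next_grad_name("bias") == "bias"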
48 changes: 48 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -141,6 +141,54 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
dz->share_meta(z);
}
}
void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
}

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
if (dl) {
dl->share_meta(l);
}
}

void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
if (dx) {
20 changes: 20 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -85,6 +85,26 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
MetaTensor* dy,
MetaTensor* dz);

void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk);

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl);

void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);

void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
6 changes: 3 additions & 3 deletions paddle/phi/kernels/activation_grad_kernel.h
@@ -125,18 +125,18 @@ void EluDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dout_new,
DenseTensor* ddout);

template <typename T, typename Context>
void SigmoidTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& d_ddout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx);
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/activation_grad_impl.h
@@ -243,8 +243,8 @@ void LogitGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dout_new,
DenseTensor* ddout) {
if (dout_new) {
@@ -262,10 +262,10 @@ void SigmoidDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& d_ddout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx) {
4 changes: 2 additions & 2 deletions paddle/phi/ops/compat/activation_sig.cc
@@ -139,13 +139,13 @@ KernelSignature TanhTripleGradOpArgumentMapping(
KernelSignature SigmoidDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sigmoid_double_grad", {"Out", "DDX", "DOut"}, {}, {"DOutNew", "DDOut"});
"sigmoid_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
}

KernelSignature SigmoidTripleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("sigmoid_triple_grad",
{"Out", "DDX", "DOut", "D_DDOut", "D_DOut_New"},
{"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
{},
{"D_OutNew", "D_DOut", "D_DDx"});
}