Skip to content

Commit

Permalink
[DoubleGrad PR #2] Adjusted logics of GenerateNodeCreationCodes and G…
Browse files Browse the repository at this point in the history
…enerateForwardDefinition (PaddlePaddle#41016)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Fixed minor issue
  • Loading branch information
jim19930609 committed Mar 29, 2022
1 parent 93a2f56 commit 3983c72
Showing 1 changed file with 89 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,24 @@ class {} : public egr::GradNodeBase {{
FORWARD_FUNCTION_TEMPLATE = \
"""
{} {}({}) {{
{}
{}
{}
// Dygraph Record Event
{}
// AMP Logic
{}
// Get Input AutoGradMeta
{}
// Forward API Call
{}
// Get Output AutoGradMeta
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
// Check Inplace & Bump Inplace Version
{}
{}
// Node Creation
{}
// Returns
return {};
Expand All @@ -174,18 +189,8 @@ class {} : public egr::GradNodeBase {{

FORWARD_BODY_TEMPLATE = \
"""
// Get AutoGradMeta
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
{}
// Forward API Call
{}
{}
{{
{}
{}
if(require_any_grad) {{
{}
egr::EagerUtils::PassStopGradient({});
// Node Construction
Expand All @@ -203,7 +208,6 @@ class {} : public egr::GradNodeBase {{
{}
{}
}}
}}
"""

NAMESPACE_WRAPPER_TEMPLATE = \
Expand Down Expand Up @@ -294,7 +298,6 @@ class {} : public egr::GradNodeBase {{

CHECK_INPLACE_TEMPLATE = \
"""
// Check Inplace
egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
"""

Expand Down Expand Up @@ -625,7 +628,7 @@ def SlotNameMatching(self):
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
)

def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
def GenerateNodeCreationCodes(self):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
Expand All @@ -635,67 +638,14 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
backward_grad_outputs_map = self.backward_grad_outputs_map
backward_attrs_list = self.backward_attrs_list
optional_inputs = self.optional_inputs
inplace_map = self.inplace_map if is_inplaced else {}

# Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)

# Get Output AutoGradMeta
outputs_autograd_meta_list = []
# Pass Stop Gradient Args
pass_stop_gradient_args_list = ["false"]
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
for name, (_, _) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)

# ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)

# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)

# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
Expand All @@ -719,6 +669,7 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):

# SetTensorWrappers
set_tensor_wrappers_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (atype, is_fwd_input,
pos) in backward_forward_inputs_map.items():
is_optional = (name in optional_inputs)
Expand Down Expand Up @@ -794,13 +745,10 @@ def GenerateNodeCreationCodes(self, forward_call_str, is_inplaced):
node_creation_event_str = f"paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"

self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
inputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, forward_call_str, bump_inplace_version_str,
node_creation_event_str, outputs_autograd_meta_str,
pass_stop_gradient_args_str, node_construction_str,
set_attributes_str, set_tensor_wrappers_str, set_grad_out_meta_str,
set_edges_str, set_out_rank_str, set_history_str,
set_grad_in_meta_str, set_retain_grad_str)
node_creation_event_str, pass_stop_gradient_args_str,
node_construction_str, set_attributes_str, set_tensor_wrappers_str,
set_grad_out_meta_str, set_edges_str, set_out_rank_str,
set_history_str, set_grad_in_meta_str, set_retain_grad_str)

def run(self):
# Basic Validation Check
Expand Down Expand Up @@ -973,7 +921,64 @@ def GenerateForwardDefinition(self, is_inplaced):
returns_str = ", ".join(returns_list)
returns_str = f"std::make_tuple({returns_str})"

self.GenerateNodeCreationCodes(forward_call_str, is_inplaced)
# Node Creation Pre-Processing
# 1. Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)

# 2. Get Output AutoGradMeta
outputs_autograd_meta_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

outputs_autograd_meta_list.append(output_autograd_meta)

# 3. ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

# 4. Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)

self.GenerateNodeCreationCodes()

node_creation_str = self.node_creation_str
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
Expand Down Expand Up @@ -1001,7 +1006,10 @@ def GenerateForwardDefinition(self, is_inplaced):

self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, node_creation_str, returns_str)
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
forward_call_str, outputs_autograd_meta_str,
compute_require_grad_args_str, check_inplace_str,
bump_inplace_version_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"

logging.info(
Expand Down

0 comments on commit 3983c72

Please sign in to comment.