Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] refactored eager_gen.py PR #2 #40907

Merged
merged 1 commit into from Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -46,6 +46,10 @@
#############################
### File Reader Helpers ###
#############################
def AssertMessage(lhs_str, rhs_str):
return f"lhs: {lhs_str}, rhs: {rhs_str}"


def ReadFwdFile(filepath):
f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader)
Expand All @@ -58,10 +62,10 @@ def ReadBwdFile(filepath):
contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {}
for content in contents:
assert 'backward_api' in content.keys(), AssertMessage('backward_api',
content.keys())
if 'backward_api' in content.keys():
api_name = content['backward_api']
else:
assert False

ret[api_name] = content
f.close()
Expand Down Expand Up @@ -221,7 +225,7 @@ def ParseYamlReturns(string):
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type]

assert "Tensor" in ret_type
assert "Tensor" in ret_type, AssertMessage("Tensor", ret_type)
ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i])

Expand Down
Expand Up @@ -16,6 +16,7 @@
import re
import argparse
import os
import logging
from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info
from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, ReadBwdFile
Expand All @@ -30,6 +31,7 @@
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage


###########
Expand Down Expand Up @@ -398,14 +400,21 @@ def DygraphYamlValidationCheck(self):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents

assert 'api' in forward_api_contents.keys()
assert 'args' in forward_api_contents.keys()
assert 'output' in forward_api_contents.keys()
assert 'backward' in forward_api_contents.keys()

assert 'args' in grad_api_contents.keys()
assert 'output' in grad_api_contents.keys()
assert 'forward' in grad_api_contents.keys()
assert 'api' in forward_api_contents.keys(
), "Unable to find \"api\" in api.yaml"
assert 'args' in forward_api_contents.keys(
), "Unable to find \"args\" in api.yaml"
assert 'output' in forward_api_contents.keys(
), "Unable to find \"output\" in api.yaml"
assert 'backward' in forward_api_contents.keys(
), "Unable to find \"backward\" in api.yaml"

assert 'args' in grad_api_contents.keys(
), "Unable to find \"args\" in backward.yaml"
assert 'output' in grad_api_contents.keys(
), "Unable to find \"output\" in backward.yaml"
assert 'forward' in grad_api_contents.keys(
), "Unable to find \"forward\" in backward.yaml"

def ForwardsValidationCheck(self):
forward_inputs_list = self.forward_inputs_list
Expand All @@ -424,8 +433,10 @@ def ForwardsValidationCheck(self):
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]

assert forward_input_type == orig_input_type
assert forward_input_pos == orig_input_pos
assert forward_input_type == orig_input_type, AssertMessage(
forward_input_type, orig_input_type)
assert forward_input_pos == orig_input_pos, AssertMessage(
forward_input_pos, orig_input_pos)

for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
Expand All @@ -436,18 +447,23 @@ def ForwardsValidationCheck(self):
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
assert orig_attr_type == forward_attr_type, AssertMessage(
orig_attr_type, forward_attr_type)
assert orig_attr_default == forward_attr_default, AssertMessage(
orig_attr_default, forward_attr_default)
assert orig_attr_pos == forward_attr_pos, AssertMessage(
orig_attr_pos, forward_attr_pos)

for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][1]
orig_return_pos = orig_forward_returns_list[i][2]
forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2]

assert orig_return_type == forward_return_type
assert orig_return_pos == forward_return_pos
assert orig_return_type == forward_return_type, AssertMessage(
orig_return_type, forward_return_type)
assert orig_return_pos == forward_return_pos, AssertMessage(
orig_return_pos, forward_return_pos)

# Check Order: Inputs, Attributes
max_input_position = -1
Expand All @@ -456,7 +472,8 @@ def ForwardsValidationCheck(self):

max_attr_position = -1
for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position
assert pos > max_input_position, AssertMessage(pos,
max_input_position)
max_attr_position = max(max_attr_position, pos)

def BackwardValidationCheck(self):
Expand All @@ -471,12 +488,14 @@ def BackwardValidationCheck(self):

max_grad_tensor_position = -1
for _, (_, _, pos) in backward_grad_inputs_map.items():
assert pos > max_fwd_input_position
assert pos > max_fwd_input_position, AssertMessage(
pos, max_grad_tensor_position)
max_grad_tensor_position = max(max_grad_tensor_position, pos)

max_attr_position = -1
for _, _, _, pos in backward_attrs_list:
assert pos > max_grad_tensor_position
assert pos > max_grad_tensor_position, AssertMessage(
pos, max_grad_tensor_position)
max_attr_position = max(max_attr_position, pos)

def IntermediateValidationCheck(self):
Expand All @@ -491,7 +510,8 @@ def IntermediateValidationCheck(self):
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions
assert pos in intermediate_positions, AssertMessage(
pos, intermediate_positions)

def CollectBackwardInfo(self):
forward_api_contents = self.forward_api_contents
Expand All @@ -505,9 +525,12 @@ def CollectBackwardInfo(self):

self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward(
backward_args_str, backward_returns_str)
print("Parsed Backward Inputs List: ", self.backward_inputs_list)
print("Prased Backward Attrs List: ", self.backward_attrs_list)
print("Parsed Backward Returns List: ", self.backward_returns_list)

logging.info(
f"Parsed Backward Inputs List: {self.backward_inputs_list}")
logging.info(f"Prased Backward Attrs List: {self.backward_attrs_list}")
logging.info(
f"Parsed Backward Returns List: {self.backward_returns_list}")

def CollectForwardInfoFromBackwardContents(self):

Expand All @@ -530,7 +553,9 @@ def SlotNameMatching(self):
backward_fwd_name = FindForwardName(backward_input_name)
if backward_fwd_name:
# Grad Input
assert backward_fwd_name in forward_outputs_position_map.keys()
assert backward_fwd_name in forward_outputs_position_map.keys(
), AssertMessage(backward_fwd_name,
forward_outputs_position_map.keys())
matched_forward_output_type = forward_outputs_position_map[
backward_fwd_name][0]
matched_forward_output_pos = forward_outputs_position_map[
Expand All @@ -556,17 +581,18 @@ def SlotNameMatching(self):
backward_input_type, False, backward_input_pos
]
else:
assert False, backward_input_name
assert False, f"Cannot find {backward_input_name} in forward position map"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe raise Error is better?


for backward_output in backward_returns_list:
backward_output_name = backward_output[0]
backward_output_type = backward_output[1]
backward_output_pos = backward_output[2]

backward_fwd_name = FindForwardName(backward_output_name)
assert backward_fwd_name is not None
assert backward_fwd_name is not None, f"Detected {backward_fwd_name} = None"
assert backward_fwd_name in forward_inputs_position_map.keys(
), f"Unable to find {backward_fwd_name} in forward inputs"
), AssertMessage(backward_fwd_name,
forward_inputs_position_map.keys())

matched_forward_input_type = forward_inputs_position_map[
backward_fwd_name][0]
Expand All @@ -577,12 +603,15 @@ def SlotNameMatching(self):
backward_output_type, matched_forward_input_pos,
backward_output_pos
]
print("Generated Backward Fwd Input Map: ",
self.backward_forward_inputs_map)
print("Generated Backward Grad Input Map: ",
self.backward_grad_inputs_map)
print("Generated Backward Grad Output Map: ",
self.backward_grad_outputs_map)
logging.info(
f"Generated Backward Fwd Input Map: {self.backward_forward_inputs_map}"
)
logging.info(
f"Generated Backward Grad Input Map: {self.backward_grad_inputs_map}"
)
logging.info(
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
)

def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
Expand Down Expand Up @@ -642,7 +671,7 @@ def GenerateNodeDeclaration(self):
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)

print("Generated Node Declaration: ", self.node_declaration_str)
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")

def GenerateNodeDefinition(self):
namespace = self.namespace
Expand Down Expand Up @@ -710,7 +739,7 @@ def GenerateNodeDefinition(self):
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str)

print("Generated Node Definition: ", self.node_definition_str)
logging.info(f"Generated Node Definition: {self.node_definition_str}")

def GenerateForwardDefinition(self, is_inplaced):
namespace = self.namespace
Expand Down Expand Up @@ -813,8 +842,10 @@ def GenerateForwardDefinition(self, is_inplaced):
dygraph_event_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"

print("Generated Forward Definition: ", self.forward_definition_str)
print("Generated Forward Declaration: ", self.forward_declaration_str)
logging.info(
f"Generated Forward Definition: {self.forward_definition_str}")
logging.info(
f"Generated Forward Declaration: {self.forward_declaration_str}")

def GenerateNodeCreationCodes(self, forward_call_str):
forward_api_name = self.forward_api_name
Expand Down Expand Up @@ -921,7 +952,8 @@ def GenerateNodeCreationCodes(self, forward_call_str):
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys()
assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
Expand Down Expand Up @@ -1114,7 +1146,8 @@ def GetBackwardAPIContents(self, forward_api_contents):
if 'backward' not in forward_api_contents.keys(): return None

backward_api_name = forward_api_contents['backward']
assert backward_api_name in grad_api_dict.keys()
assert backward_api_name in grad_api_dict.keys(), AssertMessage(
backward_api_name, grad_api_dict.keys())
backward_api_contents = grad_api_dict[backward_api_name]

return backward_api_contents
Expand Down