Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Pten】Change output type of backward_api from tuple to vector #39229

Merged
merged 1 commit into from Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 37 additions & 6 deletions python/paddle/utils/code_gen/api_gen.py
Expand Up @@ -31,7 +31,12 @@ def __init__(self, api_item_yaml):
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
self.args = gen_utils.parse_args(self.api, api_item_yaml['args'])
self.output = api_item_yaml['output']
self.out_type_list, _ = gen_utils.parse_output(self.api,
api_item_yaml['output'])
self.return_type = self.out_type_list[0] if len(
self.out_type_list) == 1 else "std::tuple<" + ",".join(
self.out_type_list) + ">"

self.is_base_api = True
if 'invoke' in api_item_yaml:
self.is_base_api = False
Expand All @@ -54,18 +59,44 @@ def __init__(self, api_item_yaml):

def gene_api_declaration(self):
return f"""
PADDLE_API {self.output} {self.api}({self.args['args_declare']});
PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
"""

def gene_output(self, output_type_list):
kernel_output = ""
output_create = ""

if len(output_type_list) == 1:
kernel_output = 'dense_out'
output_create = f"""
{self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""

elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""

for i in range(len(output_type_list)):
kernel_output = kernel_output + f'dense_out_{i}, '
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, &std::get<{i}>(out));"""

kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.api))

return kernel_output, output_create

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
self.kernel['param'])
out_type, _ = gen_utils.parse_output(self.api, self.output)
outputs_args, output_create = gen_utils.gene_output(out_type)
outputs_args, output_create = self.gene_output(self.out_type_list)
return f"""
PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{
PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
{gen_utils.gene_kernel_select(self.api, self.args['inputs']['names'], self.args['attrs'], self.kernel)}

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
Expand All @@ -82,7 +113,7 @@ def gene_api_code(self):

else:
return f"""
PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{
PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
return {self.invoke};
}}
"""
Expand Down
53 changes: 42 additions & 11 deletions python/paddle/utils/code_gen/backward_api_gen.py
Expand Up @@ -23,9 +23,11 @@
class BackwardAPI:
def __init__(self, backward_item_yaml):
self.backward_api = backward_item_yaml['backward_api']
self.args, self.output_type, self.return_comment = self.parse_and_check_args(
self.args, self.output_type_list, self.return_comment = self.parse_and_check_args(
backward_item_yaml['forward'], backward_item_yaml['args'],
backward_item_yaml['output'])
self.return_type = self.output_type_list[0] if len(
self.output_type_list) == 1 else "std::vector<std::vector<Tensor>>"

self.is_base_api = True
if 'invoke' in backward_item_yaml:
Expand Down Expand Up @@ -81,36 +83,65 @@ def parse_and_check_args(self, forward_config, args_config, output_config):
Please check the args of {self.backward_api} in yaml."

# check the output of backward
output_type, return_comment = gen_utils.parse_output(self.backward_api,
output_config)
assert output_type.count('Tensor') <= len(fw_inputs['names']), \
out_type_list, return_comment = gen_utils.parse_output(
self.backward_api, output_config)
assert len(out_type_list) <= len(fw_inputs['names']), \
f"{self.backward_api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \
Please check the output of {self.backward_api} in yaml."

return bw_args, output_type, return_comment
return bw_args, out_type_list, return_comment

def gene_api_declaration(self):
if self.return_comment:
return f"""
// {self.return_comment}
{self.output_type} {self.backward_api}({self.args['args_declare']});
{self.return_type} {self.backward_api}({self.args['args_declare']});
"""

else:
return f"""
{self.output_type} {self.backward_api}({self.args['args_declare']});
{self.return_type} {self.backward_api}({self.args['args_declare']});
"""

def gene_output(self, output_type_list):
kernel_output = ""
output_create = ""

if len(output_type_list) == 1:
return_type = output_type_list[0]
kernel_output = 'dense_out'
output_create = f"""
{self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""

elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""

for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, '
get_out_code = f'&out[{i}][0]' if out_type_item == 'Tensor' else f'&out[{i}]'
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});"""

kernel_output = kernel_output[:-2]
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.backward_api))

return kernel_output, output_create

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
self.kernel['param'])
outputs_args, output_create = gen_utils.gene_output(
self.output_type)
outputs_args, output_create = self.gene_output(
self.output_type_list)
return f"""
// {self.return_comment}
{self.output_type} {self.backward_api}({self.args["args_define"]}) {{
{self.return_type} {self.backward_api}({self.args["args_define"]}) {{
{gen_utils.gene_kernel_select(self.backward_api, self.args['inputs']['names'], self.args['attrs'], self.kernel)}

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
Expand Down Expand Up @@ -143,7 +174,7 @@ def adjust_name(matched):
params_code = self.args["args_define"]
return f"""
// {self.return_comment}
{self.output_type} {self.backward_api}({params_code}) {{
{self.return_type} {self.backward_api}({params_code}) {{
return {invoke_code};
}}
"""
Expand Down
28 changes: 3 additions & 25 deletions python/paddle/utils/code_gen/gen_utils.py
Expand Up @@ -124,7 +124,7 @@ def parse_output_item(output_item):

if len(temp_list) == 1:
out_type, out_name = parse_output_item(temp_list[0])
return out_type, out_name
return [out_type], out_name
else:
out_type_list = []
out_name_list = []
Expand All @@ -133,8 +133,7 @@ def parse_output_item(output_item):
out_type_list.append(out_type)
out_name_list.append(out_name)

return "std::tuple<" + ",".join(out_type_list) + ">", ", ".join(
out_name_list)
return out_type_list, ", ".join(out_name_list)


def gene_kernel_select(api, input_names, attrs, kernel) -> str:
Expand Down Expand Up @@ -241,7 +240,7 @@ def gene_kernel_select(api, input_names, attrs, kernel) -> str:

if len(input_names) > 0:
kernel_select_code = kernel_select_code + f"""
if (kernel_backend == Backend::UNDEFINED
if (kernel_backend == Backend::UNDEFINED
|| kernel_layout == DataLayout::UNDEFINED
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
Expand Down Expand Up @@ -315,24 +314,3 @@ def get_kernel_args(input_names, attrs, kernel_param):
else:
kernel_args = kernel_args + str(param) + ", "
return input_tensor_code, kernel_args[:-2]


def gene_output(output_type):
kernel_output = ""
output_create = f"""
{output_type} out;"""

if output_type == 'Tensor' or output_type == 'std::vector<Tensor>':
kernel_output = 'dense_out'
output_create = output_create + """
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""
elif re.match(r'std::tuple<.*>$', output_type):
out_num = output_type.count('Tensor')
for i in range(out_num):
kernel_output = kernel_output + f'dense_out_{i}, '
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, &std::get<{i}>(out));"""

kernel_output = kernel_output[:-2]

return kernel_output, output_create