Skip to content

Commit

Permalink
[Dygraph API] Fix merged_momentum, provide actual inplace operations …
Browse files Browse the repository at this point in the history
…after falling back to CPU
  • Loading branch information
skywalker2012 committed Nov 20, 2023
1 parent cd6a7a4 commit 2437dba
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
24 changes: 23 additions & 1 deletion paddle/phi/api/yaml/generator/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,9 +1257,31 @@ def gen_kernel_code(self, kernel_name, code_indent, inplace_flag=False):
transdata2strided += f"""{code_indent} TransStride(dev_ctx, {kernel_out}, backup{i});\n"""
i = i + 1
fallback_kernel_output_trans = ""
for kernel_out in outputs_args:
for idx, kernel_out in enumerate(outputs_args):
fallback_kernel_output_trans += f"""
{code_indent} TransDataBackend({kernel_out}, kernel_backend, {kernel_out});"""
if (
self.outputs['types'][idx] == 'std::vector<Tensor>'
and self.outputs['names'][idx] in self.inplace_map
):
target_input = self.inplace_map[self.outputs['names'][idx]]
if (
self.inplace_map[self.outputs['names'][idx]]
in self.optional_vars
):
fallback_kernel_output_trans += f"""
{code_indent} if ({target_input}) {{
{code_indent} for (size_t i = 0; i < {target_input}->size(); ++i) {{
{code_indent} auto target_ptr = static_cast<phi::DenseTensor*>({target_input}->at(i).impl().get());
{code_indent} *target_ptr = *{kernel_out}.at(i);
{code_indent} }}
{code_indent} }}"""
else:
fallback_kernel_output_trans += f"""
{code_indent} for (size_t i = 0; i < {target_input}.size(); ++i) {{
{code_indent} auto target_ptr = static_cast<phi::DenseTensor*>({target_input}.at(i).impl().get());
{code_indent} *target_ptr = *{kernel_out}.at(i);
{code_indent} }}"""
return f"""
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
Expand Down
54 changes: 46 additions & 8 deletions paddle/phi/api/yaml/generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,29 @@ def gene_return_code(self):
]
return 'return std::make_tuple(' + ", ".join(selected_code) + ');'

def gene_fallback_code_after_gene_output_of_vector(
self, code_indent, output_idx, is_inplace, is_optional
):
fallback_code = ""
if is_inplace and is_optional:
fallback_code = f"""
{code_indent} if (kernel_result.has_fallback_cpu) {{
{code_indent} for (size_t i = 0; i < kernel_out_{output_idx}.size(); ++i) {{
{code_indent} kernel_out_{output_idx}[i] = const_cast<phi::DenseTensor*>({PREFIX_TENSOR_NAME}{self.inplace_map[self.outputs['names'][output_idx]]}->at(i));
{code_indent} }}
{code_indent} }}"""
elif is_inplace:
fallback_code = f"""
{code_indent} if (kernel_result.has_fallback_cpu) {{
{code_indent} for (size_t i = 0; i < kernel_out_{output_idx}.size(); ++i) {{
{code_indent} kernel_out_{output_idx}[i] = const_cast<phi::DenseTensor*>({PREFIX_TENSOR_NAME}{self.inplace_map[self.outputs['names'][output_idx]]}[i]);
{code_indent} }}
{code_indent} }}"""
else:
fallback_code = ""

return fallback_code

def gene_output(
self,
out_dtype_list,
Expand Down Expand Up @@ -271,14 +294,29 @@ def gene_output(
"SetInplaceOptionalVectorKernelOutput"
)
get_out_code = f"std::get<{i}>(api_output)"
output_create = (
output_create
+ f"""
{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});
{code_indent} if (kernel_result.has_fallback_cpu) {{
{code_indent} TransDataBackend(kernel_out_{i}, actual_kernel_backend, kernel_out_{i});
{code_indent} }}"""
)
output_create = (
output_create
+ f"""
{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});"""
+ self.gene_fallback_code_after_gene_output_of_vector(
code_indent, i, True, True
)
)
else:
output_create = (
output_create
+ f"""
{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});"""
+ self.gene_fallback_code_after_gene_output_of_vector(
code_indent, i, True, False
)
)
else:
output_create = (
output_create
+ f"""
{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});"""
)

else:
output_create = (
Expand Down

0 comments on commit 2437dba

Please sign in to comment.