Skip to content

Commit

Permalink
refine optional code for amp
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Mar 26, 2022
1 parent fc0b3fe commit 1e00955
Showing 1 changed file with 4 additions and 3 deletions.
Expand Up @@ -801,10 +801,10 @@ def GenerateForwardDefinition(self, is_inplaced):
if is_optional:
arg_str = f"const paddle::optional<const paddle::experimental::Tensor&> {name}"
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({name}.get());\n"
f"if ({name}.get_ptr() != nullptr) amp_tensors_vector.push_back({{ *({name}.get_ptr()) }});\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name) : {name};\n"
f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n"
)
else:
if inplace_map and name in inplace_map.keys():
Expand Down Expand Up @@ -895,7 +895,8 @@ def GenerateForwardDefinition(self, is_inplaced):
amp_tensors_vector_optional_list)
amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
amp_autocast_list_str = " ".join(
amp_autocast_list) + " ".join(amp_autocast_optional_list)
amp_autocast_list) + " " + " ".join(
amp_autocast_optional_list)
amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
amp_call_str = f"return {forward_function_name}({amp_inputs_call_args_str});"
if is_inplaced or (forward_api_name == "cast"):
Expand Down

1 comment on commit 1e00955

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.