Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dygraph API] Fix merged_momentum, provide actual inplace operations … #59161

Merged
merged 1 commit into from
Nov 22, 2023

Conversation

RuohengMa
Copy link
Contributor

@RuohengMa RuohengMa commented Nov 20, 2023

PR types

Bug fixes

PR changes

APIs

Description

在维持#58204功能的前提下修复遗留的问题,使fallback到CPU后,merged_momentum kernel里输入输出算子在CPU上的地址是一样的

以merged_momentum为例:
output部分:
修复前:

  std::tuple<std::vector<Tensor>&, std::vector<Tensor>&, paddle::optional<std::vector<Tensor>>&> api_output{param, velocity, master_param};
  auto kernel_out_0 = SetInplaceVectorKernelOutput(param.size(), &std::get<0>(api_output));
  if (kernel_result.has_fallback_cpu) {
    TransDataBackend(kernel_out_0, actual_kernel_backend, kernel_out_0);
  }
  auto kernel_out_1 = SetInplaceVectorKernelOutput(param.size(), &std::get<1>(api_output));
  if (kernel_result.has_fallback_cpu) {
    TransDataBackend(kernel_out_1, actual_kernel_backend, kernel_out_1);
  }
  auto kernel_out_2 = SetInplaceOptionalVectorKernelOutput(param.size(), std::get<2>(api_output));
  if (kernel_result.has_fallback_cpu) {
    TransDataBackend(kernel_out_2, actual_kernel_backend, kernel_out_2);
  }

修复后:

  std::tuple<std::vector<Tensor>&, std::vector<Tensor>&, paddle::optional<std::vector<Tensor>>&> api_output{param, velocity, master_param};
  auto kernel_out_0 = SetInplaceVectorKernelOutput(param.size(), &std::get<0>(api_output));
  if (kernel_result.has_fallback_cpu) {
    for (size_t i = 0; i < kernel_out_0.size(); ++i) {
      kernel_out_0[i] = const_cast<phi::DenseTensor*>(input_param[i]);
    }
  }
  auto kernel_out_1 = SetInplaceVectorKernelOutput(param.size(), &std::get<1>(api_output));
  if (kernel_result.has_fallback_cpu) {
    for (size_t i = 0; i < kernel_out_1.size(); ++i) {
      kernel_out_1[i] = const_cast<phi::DenseTensor*>(input_velocity[i]);
    }
  }
  auto kernel_out_2 = SetInplaceOptionalVectorKernelOutput(param.size(), std::get<2>(api_output));
  if (kernel_result.has_fallback_cpu) {
    for (size_t i = 0; i < kernel_out_2.size(); ++i) {
      kernel_out_2[i] = const_cast<phi::DenseTensor*>(input_master_param->at(i));
    }
  }

return前的copy back部分:
修复前:

  if (kernel_result.has_fallback_cpu) {

    TransDataBackend(kernel_out_0, kernel_backend, kernel_out_0);
    TransDataBackend(kernel_out_1, kernel_backend, kernel_out_1);
    TransDataBackend(kernel_out_2, kernel_backend, kernel_out_2);

  }

修复后:

  if (kernel_result.has_fallback_cpu) {

    TransDataBackend(kernel_out_0, kernel_backend, kernel_out_0);
    for (size_t i = 0; i < param.size(); ++i) {
      auto target_ptr = static_cast<phi::DenseTensor*>(param.at(i).impl().get());
      *target_ptr = *kernel_out_0.at(i);
    }
    TransDataBackend(kernel_out_1, kernel_backend, kernel_out_1);
    for (size_t i = 0; i < velocity.size(); ++i) {
      auto target_ptr = static_cast<phi::DenseTensor*>(velocity.at(i).impl().get());
      *target_ptr = *kernel_out_1.at(i);
    }
    TransDataBackend(kernel_out_2, kernel_backend, kernel_out_2);
    if (master_param) {
      for (size_t i = 0; i < master_param->size(); ++i) {
        auto target_ptr = static_cast<phi::DenseTensor*>(master_param->at(i).impl().get());
        *target_ptr = *kernel_out_2.at(i);
      }
    }

  }

Copy link

paddle-bot bot commented Nov 20, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@CLAassistant
Copy link

CLAassistant commented Nov 20, 2023

CLA assistant check
All committers have signed the CLA.

@paddle-bot paddle-bot bot added the contributor External developers label Nov 20, 2023
Copy link

paddle-bot bot commented Nov 20, 2023

❌ The PR is not created using PR's template. You can refer to this Demo.
Please use PR's template, it helps save our maintainers' time so that more developers get helped.

Copy link
Contributor

@qili93 qili93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for const_cast

Comment on lines +176 to +177
{code_indent} kernel_out_{output_idx}[i] = const_cast<phi::DenseTensor*>({PREFIX_TENSOR_NAME}{self.inplace_map[self.outputs['names'][output_idx]]}->at(i));
{code_indent} }}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本身就是inplace,这里为何要再次把input的指针赋值给out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为如果发生了fallback到cpu的情况的话,输入的tensor是会被从xpu copy到cpu上;因为算子是inplace,输入输出的指针应该是同一个,所以需要把kernel_out指向copy到cpu上的数据而不是指向xpu上的数据~

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里等号左右的tensor不是同一个吗?

Comment on lines +1275 to +1276
{code_indent} auto target_ptr = static_cast<phi::DenseTensor*>({target_input}->at(i).impl().get());
{code_indent} *target_ptr = *{kernel_out}.at(i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不是inplace吗,为何又要把输出给了输入

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为fallback到cpu之后,out是在cpu上的,所以需要重新把out的值写回xpu上的input~

Copy link
Contributor

@cqulilujia cqulilujia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@houj04 houj04 merged commit b50313c into PaddlePaddle:develop Nov 22, 2023
28 checks passed
SecretXV pushed a commit to SecretXV/Paddle that referenced this pull request Nov 28, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants