Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable generating code for a given subgraph. #21126

Merged
merged 10 commits into from
Nov 20, 2019

Conversation

Xreki
Copy link
Contributor

@Xreki Xreki commented Nov 12, 2019

Support generating code for a given subgraph.

Examples of generated codes

  • example 1, generating code for a given list of expressions, each expression represents a forward operation.
extern "C" __global__ void elementwise_kernel_0(int N, float* arg0, float* arg1, float* arg3, float* arg5, float* arg2, float* arg4, float* arg6, float* arg7, float* arg8) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp0 = arg0[idx];
     float tmp1 = arg1[idx];
     float tmp3 = arg3[idx];
     float tmp5 = arg5[idx];
     float tmp2 = tmp0 * tmp1;
     float tmp4 = tmp2 + tmp3;
     float tmp6 = tmp4 - tmp5;
     float tmp7 = real_max(tmp6, 0);
     float tmp8 = 1.0 / (1.0 + real_exp(- tmp7));
     arg2[idx] = tmp2;
     arg4[idx] = tmp4;
     arg6[idx] = tmp6;
     arg7[idx] = tmp7;
     arg8[idx] = tmp8;
   }
 }
  • example 2, generating code for a given list of expressions, each expression represents a backward operation.
extern "C" __global__ void elementwise_grad_kernel_0(int N, float* arg0, float* arg1, float* arg2, float* arg7, float* arg4, float* arg5, float* arg6) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp0 = arg0[idx];
     float tmp1 = arg1[idx];
     float tmp2 = arg2[idx];
     float tmp7 = arg7[idx];
     float tmp6 = tmp2 > 0 ? tmp7 : 0;
     float tmp4 = tmp6 * tmp1;
     float tmp5 = tmp6 * tmp0;
     arg4[idx] = tmp4;
     arg5[idx] = tmp5;
     arg6[idx] = tmp6;
   }
 }
  • example 3, generating code for a given subgraph, each expression represents a forward operation.
extern "C" __global__ void elementwise_kernel_1(int N, float* arg0, float* arg1, float* arg2, float* arg3, float* arg4, float* arg5, float* arg6, float* arg7, float* arg8) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp0 = arg0[idx];
     float tmp1 = arg1[idx];
     float tmp2 = arg2[idx];
     float tmp3 = arg3[idx];
     float tmp4 = 1.0 / (1.0 + real_exp(- tmp0));
     float tmp7 = tmp4 * tmp1;
     float tmp5 = 2.0 / (1.0 + real_exp(-2.0 * tmp2)) - 1.0;
     float tmp6 = tmp3 * tmp5;
     float tmp8 = tmp7 + tmp6;
     arg4[idx] = tmp4;
     arg5[idx] = tmp5;
     arg6[idx] = tmp6;
     arg7[idx] = tmp7;
     arg8[idx] = tmp8;
   }
 }
  • example 4, generating code for a given subgraph, each expression represents a backward operation.
extern "C" __global__ void elementwise_grad_kernel_1(int N, float* arg0, float* arg1, float* arg2, float* arg3, float* arg4, float* arg5, float* arg6, float* arg7, float* arg8, float* arg9, float* arg10, float* arg11, float* arg12, float* arg13, float* arg14, float* arg15, float* arg16, float* arg17) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp2 = arg2[idx];
     float tmp4 = arg4[idx];
     float tmp5 = arg5[idx];
     float tmp6 = arg6[idx];
     float tmp7 = arg7[idx];
     float tmp11 = tmp2;
     float tmp10 = tmp2;
     float tmp16 = tmp10 * tmp5;
     float tmp12 = tmp10 * tmp4;
     float tmp13 = tmp11 * tmp6;
     float tmp17 = tmp11 * tmp7;
     float tmp15 = tmp12 * (1.0 - tmp5 * tmp5);
     float tmp14 = tmp13 * tmp7 * (1.0 - tmp7);
     arg10[idx] = tmp10;
     arg11[idx] = tmp11;
     arg12[idx] = tmp12;
     arg13[idx] = tmp13;
     arg14[idx] = tmp14;
     arg15[idx] = tmp15;
     arg16[idx] = tmp16;
     arg17[idx] = tmp17;
   }
 }

@Xreki
Copy link
Contributor Author

Xreki commented Nov 14, 2019

  • 这个PR解决了Support generating code for grad_op #21066 (comment) 提到的第1个问题,即通过拆分load、computation、store代码,避免输入数据多次从memory中加载,以及中间计算结果先写入memory、再从memory加载出来的问题。
  • Support generating code for grad_op #21066 (comment) 中提到的第3个问题,在example 2中也得到验证。example 2中可以手动设置input_ids,因为可以手动将一些用不到id设置成-1。example 4中的input_ids是从subgraph提取到的,当前的单测中构造后向网络比较简单,会为前向op中每个变量创建对应的GRAD变量,因此无法验证。

Xreki added a commit to Xreki/Paddle that referenced this pull request Nov 18, 2019
* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearange of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop
Copy link
Contributor

@wangchaochaohu wangchaochaohu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Xreki Xreki merged commit 6b1e1f0 into PaddlePaddle:develop Nov 20, 2019
@Xreki Xreki deleted the pass_subgraph_generate branch November 20, 2019 05:51
seiriosPlus pushed a commit to seiriosPlus/Paddle that referenced this pull request Dec 9, 2019
* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearange of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop
seiriosPlus pushed a commit to seiriosPlus/Paddle that referenced this pull request Dec 9, 2019
* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearange of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop
Xreki added a commit to Xreki/Paddle that referenced this pull request Jan 13, 2020
* Add the first implememtation of fusion_group op PaddlePaddle#19621 (#3)

* Add the dynamic load of nvrtc, and support runtime compiling of CUDA kernel using nvrtc.
test=develop

* Call CUDA driver api to launch the kernel compiled by nvrtc.
test=develop

* Disable for mac and windows.
test=develop

* Refine the codes to support manually specified num_threads and workload_per_thread.
test=develop

* Refine the CUDA kernel to support large dims.
test=develop

* Add DeviceCodePool to manage all device codes.

* Add the first implementation fusion_group op.

* Add unit-test for fusion_group op.

* Add the check of result.

* Add the check of nvrtc in unit-test.
test=develop

* Add comment to explain the inputs, outputs and features of fusion_group op.
test=develop

* Disable fusion_group op for mac and windows.
test=develop

* Make the compiling of device code return status instead of hanging up.
test=develop

* Add the check of whether there is CUDA driver library, and do not core dump when failing to call the CUDA driver API.

* Unify fusion_group_op's input and output names.
test=develop

* Add the check of CUDA driver library in unittest.
test=develop

* Enable generating code for a given subgraph. PaddlePaddle#21126 (#4)

* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearange of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop

* Enable the detection of subgraph of grad ops.

* Generate code for detected subgraph in fusion_group_pass.

* Add an option in BuildStrategy to enable fusion_group_pass and add unittest.
test=develop

* Fix a bug when checking whether the shape of all inputs are the same.

* Add debug information.

* Remove subgraph_detector from inference/analysis to the common framework/ir directory. (#5)

test=develop

* Call subgraph_detector in fusion_group pass.
test=develop

* Disable fusion_group when WITH_GPU is OFF.
test=develop

* Refine all PADDLE_ENFORCE message.
test=develop

* Fix the case that some inputs are not defined in grad ops, and set op_role for fused op.
test=develop

* add backward gradient computation for op argsort (PaddlePaddle#22203)

* add backward gradient computation for op argsort test=developo

* use pre-commit test=develop

* fix the bug of profile update (PaddlePaddle#22207)

* fix the bug of profile update test=develop

*  add NotImplementedError for multi optimizers (PaddlePaddle#22181)

* add NotImplementedError for multi optimizers used on multi-places . test=develop

* assert error only if num_devices>1. test=develop

* set test_optimizer_in_control_flow in CMakeLists for using multi-GPU.test=develop

* support fluid-lite subgraph run resnet test=develop (PaddlePaddle#22191)

- 添加了fluid-lite子图方式运行resnet的单测
- 修改了依赖Lite的git commit id

* fix bug fot test_dygraph_mnist_fp16.py, test=develop (PaddlePaddle#22222)

* Check dygraph weight name (PaddlePaddle#22140)

* add parameter check; test=develop

* change parameter name checker in dygraph guard; test=develop

* fix test layers error; test=develop

* revert some code to develop; test=develop

* fix exampel error; test=develop

* fix comment error; test=develop

* fix comment error; test=develop

* only import used test case and function(PaddlePaddle#22208)

Co-authored-by: FlyingQianMM <245467267@qq.com>
Co-authored-by: wangchaochaohu <wangchao66@baidu.com>
Co-authored-by: liym27 <33742067+liym27@users.noreply.github.com>
Co-authored-by: Wilber <jiweibo1028@outlook.com>
Co-authored-by: zhongpu <2013000149@qq.com>
Co-authored-by: hong <43953930+phlrain@users.noreply.github.com>
Co-authored-by: Zhang Ting <709968123@qq.com>
Xreki added a commit that referenced this pull request Feb 7, 2020
* Add the first implememtation of fusion_group op #19621 (#3)

* Add the dynamic load of nvrtc, and support runtime compiling of CUDA kernel using nvrtc.
test=develop

* Call CUDA driver api to launch the kernel compiled by nvrtc.
test=develop

* Disable for mac and windows.
test=develop

* Refine the codes to support manually specified num_threads and workload_per_thread.
test=develop

* Refine the CUDA kernel to support large dims.
test=develop

* Add DeviceCodePool to manage all device codes.

* Add the first implementation fusion_group op.

* Add unit-test for fusion_group op.

* Add the check of result.

* Add the check of nvrtc in unit-test.
test=develop

* Add comment to explain the inputs, outputs and features of fusion_group op.
test=develop

* Disable fusion_group op for mac and windows.
test=develop

* Make the compiling of device code return status instead of hanging up.
test=develop

* Add the check of whether there is CUDA driver library, and do not core dump when failing to call the CUDA driver API.

* Unify fusion_group_op's input and output names.
test=develop

* Add the check of CUDA driver library in unittest.
test=develop

* Enable generating code for a given subgraph. #21126 (#4)

* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearange of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop

* Enable the detection of subgraph of grad ops.

* Generate code for detected subgraph in fusion_group_pass.

* Add an option in BuildStrategy to enable fusion_group_pass and add unittest.
test=develop

* Fix a bug when checking whether the shape of all inputs are the same.

* Add debug information.

* Remove subgraph_detector from inference/analysis to the common framework/ir directory. (#5)

test=develop

* Call subgraph_detector in fusion_group pass.
test=develop

* Disable fusion_group when WITH_GPU is OFF.
test=develop

* Refine all PADDLE_ENFORCE message.
test=develop

* Fix the case that some inputs are not defined in grad ops, and set op_role for fused op.
test=develop

* Follow review comments.
test=develop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants