
update histogram op for performance optimization, test=develop #24912

Merged
merged 14 commits into PaddlePaddle:develop on Sep 30, 2020

Conversation

@qili93 (Contributor) commented Jun 4, 2020

PR types

Performance optimization

PR changes

OPs

Describe

Address review comments in PR #24562

paddle/fluid/operators/histogram_op.cu
1.1 Rename variables, e.g. bVal => b_val
1.2 Use shared memory in the kernel to optimize CudaAtomicAdd

Changes to python/paddle/tensor/linalg.py:
2.1 Add a blank line after the code-block directive
2.2 Confirm that the English documentation previews correctly

python/paddle/fluid/tests/unittests/test_histogram_op.py
3.1 Add unit tests for the error cases
3.2 Require fp64 input data in the unit tests
3.3 Add a floating-point example to document the binning rules for floating-point input

After the CUDA code changes, the CUDA kernel run times compare as follows:

Original (non-shared-memory) version, input data shape = shape x shape
shape = 512 time = 0.23
shape = 1024 time = 0.86
shape = 4096 time = 3.37
shape = 8192 time = 13.43

Current (shared-memory) version, input data shape = shape x shape
shape = 512 time = 0.03
shape = 1024 time = 0.16
shape = 4096 time = 0.58
shape = 8192 time = 4.27

Fluid Doc PR: PaddlePaddle/docs#2732

paddle-bot-old bot commented Jun 4, 2020

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

Diff excerpt from paddle/fluid/operators/histogram_op.cu:

return bin;
IndexType bin = static_cast<int>((input_value - min_value) * nbins /
                                 (max_value - min_value));
IndexType output_index = bin < nbins - 1 ? bin : nbins - 1;
Contributor: Did the computation logic change here?

Contributor Author: The logic hasn't changed; only the variable names and the way it is written changed.

Diff excerpt from paddle/fluid/operators/histogram_op.cu:

__syncthreads();

for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
  paddle::platform::CudaAtomicAdd(&output[i], buf_hist[i]);
}
Contributor: Isn't there a problem with this implementation? buf_hist[i] doesn't seem to be used anywhere.

Contributor Author: I optimized the kernel based on this document: http://www.cudahandbook.com/wp-content/uploads/2015/03/Histograms.pdf

The original implementation used global memory, so CudaAtomicAdd(&output[outputIdx], 1) was sufficient, similar to this code:

[screenshot: global-memory histogram kernel from the CUDA Handbook]
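To make the contrast concrete, here is a minimal, self-contained sketch of that global-memory approach (hypothetical kernel name and plain atomicAdd with float/int types for illustration; the PR itself uses paddle::platform::CudaAtomicAdd and templated types):

// Global-memory histogram: every thread atomically increments the global
// bin counter directly, so all blocks contend on the same global counters.
__global__ void HistogramGlobal(const float* input, int total_elements,
                                int nbins, float min_value, float max_value,
                                int* output) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_elements;
       i += blockDim.x * gridDim.x) {
    const float input_value = input[i];
    // Range check assumed here for illustration only.
    if (input_value < min_value || input_value > max_value) continue;
    int bin = static_cast<int>((input_value - min_value) * nbins /
                               (max_value - min_value));
    int output_index = bin < nbins - 1 ? bin : nbins - 1;
    atomicAdd(&output[output_index], 1);
  }
}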

The optimized implementation uses shared memory and computes per block, so each block increments its own counters by 1, and afterwards the per-block results still need to be added together to get the correct answer. That is why CudaAtomicAdd(&output[i], buf_hist[i]) is needed; otherwise the output would be wrong. It is similar to this code from the document:

[screenshot: shared-memory per-block histogram kernel from the CUDA Handbook]
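As another minimal, self-contained sketch (hypothetical kernel name, plain atomicAdd and int counters; the PR uses paddle::platform::CudaAtomicAdd and the buf_hist shared buffer quoted above), the per-block shared-memory variant looks roughly like this and would be launched with nbins * sizeof(int) bytes of dynamic shared memory:

// Shared-memory histogram: each block accumulates a private histogram in
// shared memory, then merges its partial counts into the global output.
__global__ void HistogramShared(const float* input, int total_elements,
                                int nbins, float min_value, float max_value,
                                int* output) {
  extern __shared__ int buf_hist[];  // nbins counters, private to this block

  // Zero the per-block histogram.
  for (int i = threadIdx.x; i < nbins; i += blockDim.x) buf_hist[i] = 0;
  __syncthreads();

  // Bin the input; these atomics only contend within a single block.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_elements;
       i += blockDim.x * gridDim.x) {
    const float input_value = input[i];
    // Range check assumed here for illustration only.
    if (input_value < min_value || input_value > max_value) continue;
    int bin = static_cast<int>((input_value - min_value) * nbins /
                               (max_value - min_value));
    int output_index = bin < nbins - 1 ? bin : nbins - 1;
    atomicAdd(&buf_hist[output_index], 1);
  }
  __syncthreads();

  // Merge this block's partial histogram into the global result.
  for (int i = threadIdx.x; i < nbins; i += blockDim.x) {
    atomicAdd(&output[i], buf_hist[i]);
  }
}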

The performance difference between the two implementations is shown in the document as a blue line and a red line:

The first, global-memory (per-grid) version is the blue line and performs worse; the second, per-block version is the red line and performs much better than the blue one.

[screenshot: performance chart from the CUDA Handbook comparing the per-grid and per-block kernels]

After testing, I found that the second implementation does indeed improve performance; these are the comparison numbers in the PR description:

Original (non-shared-memory) version, input data shape = shape x shape
shape = 512 time = 0.23
shape = 1024 time = 0.86
shape = 4096 time = 3.37
shape = 8192 time = 13.43

Current (shared-memory) version, input data shape = shape x shape
shape = 512 time = 0.03
shape = 1024 time = 0.16
shape = 4096 time = 0.58
shape = 8192 time = 4.27

Contributor: I understand that your intent is to optimize performance through shared memory, but I don't see any operation on that shared-memory buffer.

Contributor Author: I made a mistake when porting the code; it was still using the global-memory approach, so the final result was correct but the performance was still poor. I changed line 57 of the .cu file, which had paddle::platform::CudaAtomicAdd(&output[output_index], 1);.

With paddle::platform::CudaAtomicAdd(&output[output_index], 1); the result is:
Data Shape is: <4196>
nbins=<8392> minval=<0> maxval=<8392>
Elapsed Time of Grid Kernel is: <12.61>
Elapsed Time of Block Kernel is: <13.59>

After changing it to paddle::platform::CudaAtomicAdd(&buf_hist[output_index], 1); the result is:
Data Shape is: <4196>
nbins=<8392> minval=<0> maxval=<8392>
Elapsed Time of Grid Kernel is: <12.59>
Elapsed Time of Block Kernel is: <1.18>

@wangchaochaohu previously approved these changes Sep 23, 2020

@wangchaochaohu (Contributor) left a comment:

LGTM

Diff excerpt from python/paddle/tensor/linalg.py:

.. code-block:: python

Contributor: Please update the example code; lines 880 and 884 no longer need to be written.

Contributor Author: Done

Diff excerpt from python/paddle/tensor/linalg.py:

print(np.array(res[0])) # [0,3,0,2,1]

Code Example 2:
Examples:
Contributor: In the description above, change Variable to Tensor.

Contributor Author: Done

@qili93 closed this Sep 30, 2020
@qili93 reopened this Sep 30, 2020
@PaddlePaddle locked as off-topic and limited conversation to collaborators Sep 30, 2020
@PaddlePaddle unlocked this conversation Sep 30, 2020
@TCChenlong (Contributor) left a comment:

LGTM

@wangchaochaohu merged commit f373269 into PaddlePaddle:develop Sep 30, 2020
@qili93 deleted the histogram_op_comments branch September 30, 2020 05:40