
Refine concat_op #8669

Merged
merged 5 commits into PaddlePaddle:develop from feature/concat_op on Mar 7, 2018

Conversation

Contributor

@chengduoZH chengduoZH commented Mar 1, 2018

Fixes #8567

Analysis of the concat operation

The input is a list of tensors plus an axis that indicates the concatenation axis. The input tensors may have any shape, but they are only allowed to differ in size along that axis.
For example, suppose the input is two tensors.

  • case 1:
    • t_a's shape: [9,2,3,4]
    • t_b's shape:[3,2,3,4]
    • axis = 0,

Obviously, the output's shape is [12,2,3,4]. To solve this case simply, we can reshape t_a to [9, 24] and t_b to [3, 24], and then concatenate the two tensors longitudinally. The output's shape is [12, 24]. In this case, we only need two copies, as shown in the sketch below.
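To make the copy count concrete, here is a minimal sketch of the axis = 0 case (illustrative helper name and float element type, not the PR's code): after the reshape, each input occupies one contiguous block of the output, so one memcpy per input is enough.

```cpp
#include <cstring>
#include <vector>

// Hypothetical helper: concatenate two row-major buffers along axis 0.
// After reshaping to 2-D, t_a is [9, 24] and t_b is [3, 24], so each input
// is one contiguous block of the [12, 24] output.
void ConcatAxis0(const std::vector<float>& t_a, const std::vector<float>& t_b,
                 std::vector<float>* out) {
  out->resize(t_a.size() + t_b.size());
  std::memcpy(out->data(), t_a.data(), t_a.size() * sizeof(float));  // copy 1
  std::memcpy(out->data() + t_a.size(), t_b.data(),
              t_b.size() * sizeof(float));                           // copy 2
}
```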

  • case 2:
    • t_a's shape: [9,2,3,4]
    • t_b's shape:[9,3,3,4]
    • axis = 1,

To solve this case simply, we can reshape t_a to [9, 2, 12] and t_b to [9, 3, 12], and then concatenate the two tensors along the second axis. The output's shape is [9, 5, 12]. In this case, we need 18 copies (9 rows x 2 inputs).

  • case 3:
    • t_a's shape: [9,2,3,4]
    • t_b's shape:[9,2,3,3]
    • axis = 3,

Firstly, we reshape t_a to [54, 4] and t_b to [54, 3], and then concatenate the two tensors horizontally. The output's shape is [54, 7]. This is the worst case: we need 108 copies (54 rows x 2 inputs).
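All three cases fit one row-wise scheme: view each input as a 2-D matrix whose row count is the product of the dimensions before the axis (1 for case 1, 9 for case 2, 54 for case 3) and whose column count is the product of the remaining dimensions. A minimal CPU sketch of this strategy (illustrative names and float element type, not the PR's exact code):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Concatenate 2-D views [rows, cols_i] into one [rows, sum(cols_i)] output.
// The total number of memcpy calls is rows * inputs.size(), which matches the
// 2 / 18 / 108 copies counted for the three cases above.
void ConcatRows(const std::vector<const float*>& inputs,
                const std::vector<int64_t>& input_cols, int64_t rows,
                float* output) {
  int64_t out_cols = 0;
  for (int64_t c : input_cols) out_cols += c;

  for (int64_t k = 0; k < rows; ++k) {            // one pass per output row
    int64_t col_offset = 0;
    for (size_t i = 0; i < inputs.size(); ++i) {  // one copy per input per row
      std::memcpy(output + k * out_cols + col_offset,
                  inputs[i] + k * input_cols[i],
                  input_cols[i] * sizeof(float));
      col_offset += input_cols[i];
    }
  }
}
```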

TODO

  • Use one CUDA kernel to complete these copies; all of the cases above can be solved by one strategy (a sketch follows).
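A hedged sketch of that single-kernel idea (illustrative names, launch mapping, and float element type; this is not the merged kernel): each thread handles one element of the 2-D output view, finds which input owns its column, and copies the corresponding element.

```cuda
// col_offsets holds the prefix sums of the inputs' column counts plus a
// trailing entry equal to out_cols: {0, cols_0, cols_0 + cols_1, ..., out_cols}.
__global__ void ConcatKernel(const float* const* inputs, const int* col_offsets,
                             int rows, int out_cols, float* output) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  if (row >= rows || col >= out_cols) return;

  // Find the input that owns this output column.
  int idx = 0;
  while (col >= col_offsets[idx + 1]) ++idx;

  int in_cols = col_offsets[idx + 1] - col_offsets[idx];
  int local_col = col - col_offsets[idx];
  output[row * out_cols + col] = inputs[idx][row * in_cols + local_col];
}
```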

@chengduoZH chengduoZH force-pushed the feature/concat_op branch 2 times, most recently from be8d8ae to 1b09ceb Compare March 3, 2018 08:27
@chengduoZH chengduoZH changed the title [WIP]Refine concat_op Refine concat_op Mar 3, 2018

// get input's cols
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
Contributor

https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/strided_memcpy.h#L53

This functor does the same thing. Should we use it here, or delete the duplicated functor?

// assume the max size of input is less than 8 and see the performance
// save origin dim
int num = outputs.size();
std::vector<paddle::framework::DDim> origin_dim(num);
Contributor

Please delete the paddle:: namespace prefix.

Contributor Author

origin_dim has been removed.

auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());

// computation
for (int k = 0; k < input_rows; ++k) {
Contributor

Same as above; it is redundant with StridedMemcpyWithAxis.

Contributor Author

This functor doesn't process GPU data, so it is not redundant with StridedMemcpyWithAxis.
For GPU data, in some cases, the functor is slower than StridedMemcpyWithAxis.

std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
Contributor

First of all, this calculation is correct. My only concern is whether we truly need the multi-processor information.

In fact, I have tested the occupancy on a Titan X (Pascal arch) with different grid and block sizes, and found that if the block size is >= 256 and the grid size is derived from the upper bound of the input size, I get a nearly 100% occupancy rate in my test.

In a word, if the block size is >= 256 (8 times the warp size), the GPU occupancy and performance will be fine.

My test code is modified from https://github.com/zchee/cuda-sample/blob/master/0_Simple/simpleOccupancy/simpleOccupancy.cu.
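For reference, a hedged sketch of the kind of occupancy check described here, following the structure of the linked simpleOccupancy sample (the dummy kernel and function names are placeholders, not the reviewer's actual test code):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void DummyConcatKernel(float* out) {}  // placeholder kernel body

// Report theoretical occupancy for a given block size, as simpleOccupancy does.
void ReportOccupancy(int block_size) {
  int max_active_blocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
                                                DummyConcatKernel, block_size,
                                                /*dynamicSMemSize=*/0);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, /*device=*/0);

  // occupancy = active warps per SM / maximum warps per SM
  float occupancy =
      (max_active_blocks * block_size / static_cast<float>(prop.warpSize)) /
      (prop.maxThreadsPerMultiProcessor / static_cast<float>(prop.warpSize));
  std::printf("block_size=%d occupancy=%.2f\n", block_size, occupancy);
}
```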

Contributor

Furthermore, if we really need the GPU device attributes, we can store this information in the DeviceContext. It is device-specific, and caching it would save a lot of heavy device-query overhead.

Contributor Author

> the grid size is derived from the upper bound of the input size

Please keep in mind that the grid size is not unlimited. Suppose the inputs are x1 and x2, each with shape (100, 1000000), and axis is 0; the kernel will fail. You can try it.

> get a nearly 100% occupancy rate

My experience is that, in some cases, 100% occupancy doesn't mean the kernel is efficient. My test code is https://gist.github.com/chengduoZH/bc20fa8c12f8f045b74240dcdad41c84. I also did some experiments and pasted the results in the comments.

> Furthermore, if we really need the GPU device attributes, we can store this information in the DeviceContext. It is device-specific, and caching it would save a lot of heavy device-query overhead.

I agree with that.
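A minimal sketch of the caching idea both sides agree on (hypothetical struct, not Paddle's actual DeviceContext): query the device attributes once and reuse them for every launch, instead of calling cudaGetDeviceProperties per operator run.

```cuda
#include <cuda_runtime.h>

// Hypothetical cache of the device attributes used to size the concat launch.
struct CachedGpuInfo {
  int multi_processor_count = 0;
  int max_threads_per_mp = 0;
  int max_grid_dim[3] = {0, 0, 0};  // note: the y/z grid dimensions are limited
};

CachedGpuInfo QueryGpuInfoOnce(int device_id) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);  // heavy query, do it only once
  CachedGpuInfo info;
  info.multi_processor_count = prop.multiProcessorCount;
  info.max_threads_per_mp = prop.maxThreadsPerMultiProcessor;
  for (int i = 0; i < 3; ++i) info.max_grid_dim[i] = prop.maxGridSize[i];
  return info;
}
```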

@chengduoZH chengduoZH force-pushed the feature/concat_op branch 2 times, most recently from f8ada6b to 9ea6113 Compare March 5, 2018 14:57
@chengduoZH
Contributor Author

Some experiments about concat_op

Three strategies are compared:

  • make a direct D->D copy
  • first copy to the host side, concat there, and finally copy back to the GPU side
  • use the CUDA kernel entirely

There are three cases of concat:

  • axis = 0
  • axis = mid-axis
  • axis = end-axis

Experimental data

  • axis = 0
    Data dimension:
    • [100,100,100] -> (reshaped to) 1 x 10^6
    • [100,100,1000] -> (reshaped to) 1 x 10^7
    • [100,1000,1000] -> (reshaped to) 1 x 10^8
  • axis = 1
    Data dimension:
    • [100,100,100] -> (reshaped to) 100 x 10^4
    • [100,100,1000] -> (reshaped to) 100 x 10^5
    • [100,1000,1000] -> (reshaped to) 100 x 10^6
    • [1000,100,1000] -> (reshaped to) 1000 x 10^5
  • axis = 2
    Data dimension:
    • [100,100,100] -> (reshaped to) 10^4 x 100
    • [100,100,1000] -> (reshaped to) 10^4 x 1000
    • [100,1000,1000] -> (reshaped to) 10^5 x 1000

Experimental method and results:

Experimental method: take two tensors of the same shape (x1, x2), reshape each into 2 dimensions according to the axis (so every tensor has the same number of rows), repeat the concat operation 300 times, and report the average time.
The experimental results are as follows:

axis = 0:

input_shape | copy(D->D) | D->H->concat->D | kernel
----------- | ---------- | --------------- | ------
1 x 10^6 | 0.074824667 | / | 0.100071
1 x 10^7 | 0.472866667 | / | 0.568678
1 x 10^8 | 4.530733333 | / | 9.6112

axis = 1:

input_shape | copy(D->D) | D->H->concat->D | kernel
----------- | ---------- | --------------- | ------
100 x 10^4 | 0.893866667 | 3.3531 | 0.076095
100 x 10^5 | 1.227483333 | 34.148 | 0.58596
1000 x 10^4 | 10.02413333 | 35.25233333 | 0.5465
100 x 10^6 | 5.195033333 | 362.9734997 | 9.7122
1000 x 10^5 | 10.5971 | 391.4569667 | 9.506333333

axis = 2:

input_shape | copy(D->D) | D->H->concat->D | kernel
----------- | ---------- | --------------- | ------
10^4 x 100 | 87.12066667 | 3.2949 | 0.078481
10^4 x 1000 | 89.871 | 39.51366667 | 0.55338
10^5 x 1000 | 938.5879967 | 389.6895 | 8.571266667

Analysis of the experimental results

Assume the reshaped input data has shape N x M. When M is far greater than N, directly copying takes less time; in the other cases, directly using the kernel takes less time.

Currently, the implementation of concat uses both strategies, direct copying and the CUDA kernel. When the axis is zero and the number of inputs is less than 10, concat_op copies directly; otherwise, concat_op uses the CUDA kernel.
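A hedged sketch of that dispatch rule (assumed structure using plain CUDA runtime calls; the function and parameter names are illustrative, not the merged concat_op code):

```cuda
#include <cuda_runtime.h>
#include <vector>

// Choose between direct D->D copies and the concat kernel, mirroring the rule
// above: axis == 0 with fewer than 10 inputs -> direct copy, else kernel.
void ConcatDispatch(const std::vector<const float*>& dev_inputs,
                    const std::vector<size_t>& elem_counts, int axis,
                    float* dev_output, cudaStream_t stream) {
  if (axis == 0 && dev_inputs.size() < 10) {
    size_t offset = 0;
    for (size_t i = 0; i < dev_inputs.size(); ++i) {
      // One contiguous device-to-device copy per input.
      cudaMemcpyAsync(dev_output + offset, dev_inputs[i],
                      elem_counts[i] * sizeof(float),
                      cudaMemcpyDeviceToDevice, stream);
      offset += elem_counts[i];
    }
  } else {
    // Reshape the inputs to 2-D and launch the element-wise concat kernel
    // (see the ConcatKernel sketch earlier in this thread).
  }
}
```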

template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;

Collaborator

Why do we need a functor? We could just write the code directly in the op's .cc/.cu files.

Contributor Author
@chengduoZH chengduoZH Mar 7, 2018

Currently, ConcatFunctor is only used by concat_op, but I think it can be used by SplitOp too. ConcatOp and SplitOp are opposite operations, so I define concat as a functor.
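A hedged illustration of the reuse argument (hypothetical types, not the actual math/concat.h interface): splitting is the inverse copy pattern of concatenation, so a shared functor pair lets split_op reuse the same row-wise loops, for example using concat for its gradient.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

struct Tensor2D {  // stand-in for a reshaped framework::Tensor
  float* data;
  int64_t rows, cols;
};

// Inverse of the concat copy pattern: scatter each input row's column ranges
// back into the per-output tensors. split_op's forward pass and concat_op's
// gradient could share this loop, just as split_op's gradient could reuse
// the concat loop.
struct SplitFunctor {
  void operator()(const Tensor2D& input, std::vector<Tensor2D>* outputs) const {
    for (int64_t k = 0; k < input.rows; ++k) {
      int64_t offset = 0;
      for (auto& out : *outputs) {
        std::memcpy(out.data + k * out.cols,
                    input.data + k * input.cols + offset,
                    out.cols * sizeof(float));
        offset += out.cols;
      }
    }
  }
};
```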

@chengduoZH chengduoZH merged commit 84aea8a into PaddlePaddle:develop Mar 7, 2018
@chengduoZH chengduoZH added this to Done in Performance Tuning Mar 8, 2018
@chengduoZH chengduoZH removed this from Done in Performance Tuning Mar 8, 2018
Development

Successfully merging this pull request may close these issues.

[Speed]concat operator need to be enhanced
3 participants