From da6ba39d72ffcb608614bb843b393c633feb2787 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Thu, 8 Dec 2022 18:10:07 +0000 Subject: [PATCH] refine performance with fast_divmod --- paddle/phi/kernels/gpu/stack_kernel.cu | 145 +++++++++++++++++-------- 1 file changed, 100 insertions(+), 45 deletions(-) diff --git a/paddle/phi/kernels/gpu/stack_kernel.cu b/paddle/phi/kernels/gpu/stack_kernel.cu index ebac404516098..f44ff9de773c7 100644 --- a/paddle/phi/kernels/gpu/stack_kernel.cu +++ b/paddle/phi/kernels/gpu/stack_kernel.cu @@ -18,37 +18,93 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/fast_divmod.h" namespace phi { +template +struct DivmodWarpper { + public: + __host__ void SetDivden(IndexT dividen) { + divmoder = phi::funcs::FastDivMod(dividen); + } + __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) { + return divmoder.Divmod(val); + } + + private: + phi::funcs::FastDivMod divmoder; +}; + +template <> +struct DivmodWarpper { + public: + using DivModT = phi::AlignedVector; + + __host__ void SetDivden(int64_t dividen) { dividen_ = dividen; } + __device__ inline DivModT div_mod(int64_t val) { + DivModT data; + data[0] = val / dividen_; + data[1] = val - data[0] * dividen_; + return data; + } + + private: + int64_t dividen_; +}; + constexpr int kWarpperSize = 256; -template -struct DataWarpper { +template +struct DataWarpper : public DivmodWarpper { const T* data[kWarpperSize]; - HOSTDEVICE inline const T* operator[](int i) const { return data[i]; } }; -template -__global__ void StackCUDAKernel(WarpT input_ptrs, - IntType split_size, - IntType rows, - IntType cols, +template +struct DataWarpper : public DivmodWarpper { + T** data; +}; + +template +T** PackDataAndTransfer(const Context& dev_ctx, + const std::vector& x, + int num) { + std::vector x_datas(num); + for (int i = 0; i < num; ++i) { + x_datas[i] = x[i]->data(); + } + auto byte_len = num * sizeof(T*); + auto tmp_x_data = paddle::memory::Alloc( + dev_ctx.GetPlace(), + byte_len, + phi::Stream(reinterpret_cast(dev_ctx.stream()))); + paddle::memory::Copy(dev_ctx.GetPlace(), + tmp_x_data->ptr(), + phi::CPUPlace(), + reinterpret_cast(x_datas.data()), + byte_len, + dev_ctx.stream()); + return reinterpret_cast(tmp_x_data->ptr()); +} + +template +__global__ void StackCUDAKernel(WarpT input_warpper, + IndexT split_size, + IndexT rows, + IndexT cols, T* __restrict__ output) { - IntType grid_x = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - IntType grid_x_stride = static_cast(blockDim.x) * gridDim.x; - IntType grid_y_stride = static_cast(blockDim.y) * gridDim.y; + IndexT grid_x = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + IndexT grid_x_stride = static_cast(blockDim.x) * gridDim.x; + IndexT grid_y_stride = static_cast(blockDim.y) * gridDim.y; for (; grid_x < cols; grid_x += grid_x_stride) { - IntType grid_y = - static_cast(blockIdx.y) * blockDim.y + threadIdx.y; + IndexT grid_y = static_cast(blockIdx.y) * blockDim.y + threadIdx.y; - IntType split = grid_x / split_size; - const T* input_ptr = input_ptrs[split]; - IntType col_offset = grid_x % split_size; + auto divmod_rslt = input_warpper.div_mod(grid_x); + const T* input_ptr = input_warpper.data[divmod_rslt[0]]; #pragma unroll for (; grid_y < rows; grid_y += grid_y_stride) { output[grid_y * cols + grid_x] = - input_ptr[grid_y * split_size + col_offset]; + input_ptr[grid_y * split_size + divmod_rslt[1]]; } } } @@ -72,47 +128,46 @@ void StackKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row); -#define IMPL_STACK_CUDA_KERNEL(index_t, input_data) \ - StackCUDAKernel \ +#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper) \ + StackCUDAKernel \ <<>>(input_data, \ + dev_ctx.stream()>>>(input_warpper, \ static_cast(x_col), \ static_cast(x_row), \ static_cast(out_col), \ y_data); - if (n <= kWarpperSize) { - DataWarpper data_warpper; - for (auto i = 0; i < n; ++i) { - data_warpper.data[i] = x[i]->data(); - } - if (out->numel() < std::numeric_limits::max()) { + if (out->numel() < std::numeric_limits::max()) { + if (n <= kWarpperSize) { + DataWarpper data_warpper; + for (auto i = 0; i < n; ++i) { + data_warpper.data[i] = x[i]->data(); + } + data_warpper.SetDivden(x_col); IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper); } else { - IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper); + DataWarpper data_warpper; + T** pack_ptr = PackDataAndTransfer(dev_ctx, x, n); + data_warpper.data = pack_ptr; + data_warpper.SetDivden(x_col); + IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper); } } else { - std::vector x_datas(n); - for (int i = 0; i < n; i++) { - x_datas[i] = x[i]->data(); - } - auto tmp_x_data = paddle::memory::Alloc( - dev_ctx.GetPlace(), - x_datas.size() * sizeof(T*), - phi::Stream(reinterpret_cast(dev_ctx.stream()))); - paddle::memory::Copy(dev_ctx.GetPlace(), - tmp_x_data->ptr(), - phi::CPUPlace(), - reinterpret_cast(x_datas.data()), - x_datas.size() * sizeof(T*), - dev_ctx.stream()); - - if (out->numel() < std::numeric_limits::max()) { - IMPL_STACK_CUDA_KERNEL(int32_t, reinterpret_cast(tmp_x_data->ptr())); + if (n <= kWarpperSize) { + DataWarpper data_warpper; + for (auto i = 0; i < n; ++i) { + data_warpper.data[i] = x[i]->data(); + } + data_warpper.SetDivden(x_col); + IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper); } else { - IMPL_STACK_CUDA_KERNEL(int64_t, reinterpret_cast(tmp_x_data->ptr())); + DataWarpper data_warpper; + T** pack_ptr = PackDataAndTransfer(dev_ctx, x, n); + data_warpper.data = pack_ptr; + data_warpper.SetDivden(x_col); + IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper); } } #undef IMPL_STACK_CUDA_KERNEL