refine performance with fast_divmod
JamesLim-sy committed Dec 9, 2022
1 parent 945bf4b commit da6ba39
Showing 1 changed file with 100 additions and 45 deletions.
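
The win here comes from swapping the per-element integer '/' and '%' inside StackCUDAKernel for phi::funcs::FastDivMod, which precomputes a magic multiplier and shift for the fixed divisor (the per-input column count) once on the host. As background, below is a minimal, self-contained sketch of that magic-number division; FastDivModSketch is a hypothetical name and this is not Paddle's exact implementation, just the standard Granlund-Montgomery construction, which the kernel's guard (out->numel() < INT32_MAX) keeps within its valid range:

#include <cassert>
#include <cstdint>

// Hypothetical sketch, not Paddle's FastDivMod: for a fixed divisor d, find
// shift = ceil(log2(d)) and a 32-bit magic multiplier so that for any
// n < 2^31, q = (mulhi(n, mul) + n) >> shift equals n / d.
struct FastDivModSketch {
  uint32_t divisor;
  uint32_t shift;
  uint32_t mul;

  explicit FastDivModSketch(uint32_t d) : divisor(d) {
    shift = 0;
    while ((1u << shift) < d) ++shift;
    uint64_t one = 1;
    // Low 32 bits of ceil(2^(32+shift) / d); the implicit high bit is
    // restored by the "+ n" in divmod below.
    mul = static_cast<uint32_t>(((one << 32) * ((one << shift) - d)) / d + 1);
  }

  // Quotient and remainder of n / divisor without a hardware division.
  void divmod(uint32_t n, uint32_t* q, uint32_t* r) const {
    uint32_t hi = static_cast<uint32_t>((static_cast<uint64_t>(n) * mul) >> 32);
    *q = (hi + n) >> shift;  // valid while n < 2^31, so hi + n cannot wrap
    *r = n - *q * divisor;
  }
};

int main() {
  FastDivModSketch dm(7);
  for (uint32_t n = 0; n < 100000; ++n) {
    uint32_t q, r;
    dm.divmod(n, &q, &r);
    assert(q == n / 7 && r == n % 7);
  }
  return 0;
}

On device the divide reduces to one __umulhi, an add, and a shift, which is typically much cheaper than a hardware integer division per element.
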
145 changes: 100 additions & 45 deletions paddle/phi/kernels/gpu/stack_kernel.cu
@@ -18,37 +18,93 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/fast_divmod.h"

namespace phi {

template <typename IndexT>
struct DivmodWarpper {
 public:
  __host__ void SetDivden(IndexT dividen) {
    divmoder = phi::funcs::FastDivMod(dividen);
  }
  __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
    return divmoder.Divmod(val);
  }

 private:
  phi::funcs::FastDivMod divmoder;
};
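// In the generic path above, FastDivMod::DivModT packs both results of one
// divmod: index [0] holds the quotient and [1] the remainder, mirroring the
// int64_t specialization below.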

template <>
struct DivmodWarpper<int64_t> {
 public:
  using DivModT = phi::AlignedVector<int64_t, 2>;

  __host__ void SetDivden(int64_t dividen) { dividen_ = dividen; }
  __device__ inline DivModT div_mod(int64_t val) {
    DivModT data;
    data[0] = val / dividen_;
    data[1] = val - data[0] * dividen_;
    return data;
  }

 private:
  int64_t dividen_;
};
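// FastDivMod's precomputed multiplier only covers divisors and dividends that
// fit in 32 bits, hence this int64_t specialization keeps the ordinary '/'
// plus a multiply-and-subtract remainder instead of a magic number.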

constexpr int kWarpperSize = 256;

template <typename T, typename IndexT, bool IsDataWarpperd>
struct DataWarpper : public DivmodWarpper<IndexT> {
  const T* data[kWarpperSize];
};

template <typename T, typename IndexT>
struct DataWarpper<T, IndexT, false> : public DivmodWarpper<IndexT> {
  T** data;
};

template <typename Context, typename T>
T** PackDataAndTransfer(const Context& dev_ctx,
                        const std::vector<const DenseTensor*>& x,
                        int num) {
  std::vector<const T*> x_datas(num);
  for (int i = 0; i < num; ++i) {
    x_datas[i] = x[i]->data<T>();
  }
  auto byte_len = num * sizeof(T*);
  auto tmp_x_data = paddle::memory::Alloc(
      dev_ctx.GetPlace(),
      byte_len,
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
  paddle::memory::Copy(dev_ctx.GetPlace(),
                       tmp_x_data->ptr(),
                       phi::CPUPlace(),
                       reinterpret_cast<void*>(x_datas.data()),
                       byte_len,
                       dev_ctx.stream());
  return reinterpret_cast<T**>(tmp_x_data->ptr());
}
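// Note: the returned T** aliases memory owned by the temporary Allocation,
// which goes out of scope when this function returns. This is presumably safe
// here because the buffer is allocated against dev_ctx.stream() and Paddle's
// stream-aware allocator defers the actual free until queued work completes;
// with a plain allocator this pattern would be a use-after-free.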

template <typename T, typename IndexT, typename WarpT>
__global__ void StackCUDAKernel(WarpT input_warpper,
                                IndexT split_size,
                                IndexT rows,
                                IndexT cols,
                                T* __restrict__ output) {
  IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
  IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;

  for (; grid_x < cols; grid_x += grid_x_stride) {
    IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;

    auto divmod_rslt = input_warpper.div_mod(grid_x);
    const T* input_ptr = input_warpper.data[divmod_rslt[0]];
#pragma unroll
    for (; grid_y < rows; grid_y += grid_y_stride) {
      output[grid_y * cols + grid_x] =
          input_ptr[grid_y * split_size + divmod_rslt[1]];
    }
  }
}
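// Kernel shape: grid_x strides over the flattened output columns and grid_y
// over rows. div_mod(grid_x) splits a column index into divmod_rslt[0] (which
// of the stacked inputs) and divmod_rslt[1] (the column within that input),
// replacing the '/' and '%' pair formerly in this hot loop.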
@@ -72,47 +128,46 @@ void StackKernel(const Context& dev_ctx,
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);

#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper)        \
  StackCUDAKernel<T, index_t, decltype(input_warpper)>        \
      <<<config.block_per_grid,                               \
         config.thread_per_block,                             \
         0,                                                   \
         dev_ctx.stream()>>>(input_warpper,                   \
                             static_cast<index_t>(x_col),     \
                             static_cast<index_t>(x_row),     \
                             static_cast<index_t>(out_col),   \
                             y_data);

  if (out->numel() < std::numeric_limits<int32_t>::max()) {
    if (n <= kWarpperSize) {
      DataWarpper<T, int32_t, true> data_warpper;
      for (auto i = 0; i < n; ++i) {
        data_warpper.data[i] = x[i]->data<T>();
      }
      data_warpper.SetDivden(x_col);
      IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper);
    } else {
      DataWarpper<T, int32_t, false> data_warpper;
      T** pack_ptr = PackDataAndTransfer<Context, T>(dev_ctx, x, n);
      data_warpper.data = pack_ptr;
      data_warpper.SetDivden(x_col);
      IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper);
    }
  } else {
    if (n <= kWarpperSize) {
      DataWarpper<T, int64_t, true> data_warpper;
      for (auto i = 0; i < n; ++i) {
        data_warpper.data[i] = x[i]->data<T>();
      }
      data_warpper.SetDivden(x_col);
      IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper);
    } else {
      DataWarpper<T, int64_t, false> data_warpper;
      T** pack_ptr = PackDataAndTransfer<Context, T>(dev_ctx, x, n);
      data_warpper.data = pack_ptr;
      data_warpper.SetDivden(x_col);
      IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper);
    }
  }
#undef IMPL_STACK_CUDA_KERNEL
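
For reference, the copy the kernel performs is equivalent to this host-side loop (a sketch reusing the kernel's names, with out_col == n * x_col and each input laid out row-major as [x_row, x_col]; not part of the commit):

// Reference semantics of the stacked copy (CPU sketch).
for (int64_t row = 0; row < x_row; ++row) {
  for (int64_t col = 0; col < out_col; ++col) {
    int64_t which = col / x_col;   // divmod_rslt[0] on the GPU path
    int64_t offset = col % x_col;  // divmod_rslt[1]
    y_data[row * out_col + col] = x[which]->data<T>()[row * x_col + offset];
  }
}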
