Skip to content

Commit

Permalink
Revert "[PIR+CINN]cinn cuda support warp reduce"
Browse files Browse the repository at this point in the history
  • Loading branch information
phlrain committed Jan 19, 2024
1 parent f335680 commit ed126da
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,6 @@ EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
#define CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL(TYPE, value, init_value, cinn_warp_shuffle_internal) \
int warp_id = threadIdx.x / 32; \
TYPE tmp_val = cinn_warp_shuffle_internal(value); \
if ( return_warp ) return tmp_val; \
if (blockDim.x <= 32) { \
return tmp_val; \
} \
Expand All @@ -594,9 +593,9 @@ EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_INTERNAL_MACRO)
__syncthreads(); \
return shm[0];

#define CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
__device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE##_internal_shm(const DTYPE value, DTYPE* shm, bool return_warp = false) { \
CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL(DTYPE, value, (DTYPE)(INITIAL_VALUE), cinn_warp_shuffle_##REDUCE_TYPE##_internal); \
#define CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
__device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE##_internal_shm(const DTYPE value, DTYPE* shm) { \
CINN_BLOCK_REDUCE_INTERNAL_SHM_IMPL(DTYPE, value, (DTYPE)(INITIAL_VALUE), cinn_warp_shuffle_##REDUCE_TYPE##_internal); \
}

EXPAND_REDUCE_INT32_MARCO(CINN_BLOCK_REDUCE_INTERNAL_SHM_MACRO)
Expand Down

0 comments on commit ed126da

Please sign in to comment.