diff --git a/oneflow/core/cuda/softmax.cuh b/oneflow/core/cuda/softmax.cuh index 77c2ff296d7..ff752c2e0a1 100644 --- a/oneflow/core/cuda/softmax.cuh +++ b/oneflow/core/cuda/softmax.cuh @@ -38,9 +38,9 @@ struct MaxOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); } }; -template typename ReductionOp, typename T> +template typename ReductionOp, typename T, int thread_group_width = kWarpSize> __inline__ __device__ T WarpAllReduce(T val) { - for (int mask = kWarpSize / 2; mask > 0; mask /= 2) { + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask)); } return val; @@ -115,31 +115,18 @@ inline int GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves) { } template -struct GetComputeType { +struct DefaultComputeType { using type = T; }; template<> -struct GetComputeType { +struct DefaultComputeType { using type = float; }; template -struct GetPackType; - -template -struct GetPackType { - using type = T; -}; - -template<> -struct GetPackType { - using type = half2; -}; - -template<> -struct GetPackType { - using type = char2; +struct GetPackType { + using type = typename std::aligned_storage::type; }; template @@ -155,11 +142,11 @@ union Pack { T elem[N]; }; -template -struct DirectFetch { - DirectFetch(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} - template - __device__ void fetch(DST* dst, int64_t row, int64_t col) const { +template +struct DirectLoad { + DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {} + template + __device__ void load(DST* dst, int64_t row, int64_t col) const { Pack pack; const int64_t offset = row * row_size + col; pack.storage = *reinterpret_cast*>(src + offset); @@ -170,10 +157,10 @@ struct DirectFetch { int64_t row_size; }; -template +template struct DirectStore { DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} - template + template __device__ void store(const SRC* src, int64_t row, int64_t col) { Pack pack; const int64_t offset = row * row_size + col; @@ -185,91 +172,135 @@ struct DirectStore { int64_t row_size; }; -template -__global__ void SoftmaxWarpImpl(FETCH fetch, STORE store, const int64_t rows, const int64_t cols) { +template +__global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) { static_assert(cols_per_thread % pack_size == 0, ""); + static_assert(thread_group_width <= kWarpSize, ""); + static_assert(kWarpSize % thread_group_width == 0, ""); constexpr int num_packs = cols_per_thread / pack_size; - assert(cols <= cols_per_thread * kWarpSize); - using ComputeType = typename GetComputeType::type; - ComputeType buf[cols_per_thread]; - const int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y; - const int num_global_warp = gridDim.x * blockDim.y; + assert(cols <= cols_per_thread * thread_group_width); + ComputeType buf[rows_per_access][cols_per_thread]; + const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; + const int num_global_thread_group = gridDim.x * blockDim.y; const int lane_id = threadIdx.x; - for (int64_t row = global_warp_id; row < rows; row += num_global_warp) { - ComputeType thread_max = -Inf(); + for (int64_t row = global_thread_group_id * rows_per_access; row < rows; + row += num_global_thread_group * rows_per_access) { + ComputeType thread_max[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + thread_max[row_id] = -Inf(); + 
ComputeType* row_buf = buf[row_id]; #pragma unroll - for (int pack_id = 0; pack_id < num_packs; ++pack_id) { - const int col = (pack_id * kWarpSize + lane_id) * pack_size; - if (!padding || col < cols) { - fetch.template fetch(buf + pack_id * pack_size, row, col); + for (int pack_id = 0; pack_id < num_packs; ++pack_id) { + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + if (!padding || col < cols) { + load.template load(row_buf + pack_id * pack_size, row + row_id, col); #pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_max = max(thread_max, buf[pack_id * pack_size + i]); + for (int i = 0; i < pack_size; ++i) { + thread_max[row_id] = max(thread_max[row_id], row_buf[pack_id * pack_size + i]); + } + } else { +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + row_buf[pack_id * pack_size + i] = -Inf(); + } } - } else { + } + } + ComputeType warp_max[rows_per_access]; #pragma unroll - for (int i = 0; i < pack_size; ++i) { buf[pack_id * pack_size + i] = -Inf(); } + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + warp_max[row_id] = WarpAllReduce(thread_max[row_id]); + } + ComputeType thread_sum[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + thread_sum[row_id] = 0; + ComputeType* row_buf = buf[row_id]; +#pragma unroll + for (int i = 0; i < cols_per_thread; ++i) { + row_buf[i] = Exp(row_buf[i] - warp_max[row_id]); + thread_sum[row_id] += row_buf[i]; } } - const ComputeType warp_max = WarpAllReduce(thread_max); - ComputeType thread_sum = 0; + ComputeType warp_sum[rows_per_access]; #pragma unroll - for (int i = 0; i < cols_per_thread; ++i) { - buf[i] = Exp(buf[i] - warp_max); - thread_sum += buf[i]; + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + warp_sum[row_id] = WarpAllReduce(thread_sum[row_id]); } - const ComputeType warp_sum = WarpAllReduce(thread_sum); #pragma unroll - for (int i = 0; i < cols_per_thread; ++i) { buf[i] = Div(buf[i], warp_sum); } + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + ComputeType* row_buf = buf[row_id]; #pragma unroll - for (int i = 0; i < num_packs; ++i) { - const int col = (i * kWarpSize + lane_id) * pack_size; - if (!padding || col < cols) { - store.template store(buf + i * pack_size, row, col); + for (int i = 0; i < cols_per_thread; ++i) { row_buf[i] = Div(row_buf[i], warp_sum[row_id]); } +#pragma unroll + for (int i = 0; i < num_packs; ++i) { + const int col = (i * thread_group_width + lane_id) * pack_size; + if (!padding || col < cols) { + store.template store(row_buf + i * pack_size, row + row_id, col); + } } } } } -template -inline void LaunchSoftmaxWarpImpl(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, +template +inline void LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 128; constexpr int waves = 32; - static_assert(block_size % kWarpSize == 0, ""); - constexpr int rows_per_block = block_size / kWarpSize; - dim3 block_dim(kWarpSize, rows_per_block); + static_assert(block_size % thread_group_width == 0, ""); + constexpr int rows_per_block = block_size / thread_group_width; + dim3 block_dim(thread_group_width, rows_per_block); const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; const int grid_dim_x = GetNumBlocks(block_size, num_blocks, waves); - SoftmaxWarpImpl - <<>>(fetch, store, rows, cols); + SoftmaxWarpImpl + <<>>(load, store, rows, cols); } -template -inline void 
DispatchSoftmaxWarpImplPadding(cudaStream_t stream, FETCH fetch, STORE store, +template +inline void DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { - if (cols == cols_per_thread * kWarpSize) { - LaunchSoftmaxWarpImpl(stream, fetch, store, - rows, cols); + if (cols == cols_per_thread * thread_group_width) { + LaunchSoftmaxWarpImpl(stream, load, store, rows, cols); } else { - LaunchSoftmaxWarpImpl(stream, fetch, store, - rows, cols); + LaunchSoftmaxWarpImpl(stream, load, store, rows, cols); } } -template +template typename std::enable_if::type DispatchSoftmaxWarpImplCols( - cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, const int64_t cols) { + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { UNIMPLEMENTED(); } -#define DEFINE_ONE_ELIF(col) \ - else if (cols <= (col)*kWarpSize) { \ - DispatchSoftmaxWarpImplPadding(stream, fetch, store, rows, \ - cols); \ +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); \ + } else { \ + DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); \ + } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchSoftmaxWarpImplPadding( \ + stream, load, store, rows, cols); \ + } + DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(5) @@ -306,17 +337,33 @@ typename std::enable_if::type DispatchSoftmaxWarpImplCols( } } -template +template typename std::enable_if::type DispatchSoftmaxWarpImplCols( - cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, const int64_t cols) { + cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { UNIMPLEMENTED(); } -#define DEFINE_ONE_ELIF(col) \ - else if (cols <= (col)*kWarpSize) { \ - DispatchSoftmaxWarpImplPadding(stream, fetch, store, rows, \ - cols); \ +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); \ + } else { \ + DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols); \ + } \ } + DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchSoftmaxWarpImplPadding( \ + stream, load, store, rows, cols); \ + } + DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(10) @@ -337,36 +384,27 @@ typename std::enable_if::type DispatchSoftmaxWarpImplCols( } } -template +template struct DispatchSoftmaxWarpImplPackSize { - void operator()(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, - const int64_t cols) { - DispatchSoftmaxWarpImplCols(stream, fetch, store, rows, cols); - } -}; - -template -struct DispatchSoftmaxWarpImplPackSize { - void operator()(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, + void operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { - if (cols % 2 == 0 && cols > kWarpSize) { - DispatchSoftmaxWarpImplCols(stream, fetch, store, rows, cols); + if (cols % 
2 == 0) { + DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols); } else { - DispatchSoftmaxWarpImplCols(stream, fetch, store, rows, cols); + DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols); } } }; -template -inline void DispatchSoftmaxWarpImpl(cudaStream_t stream, FETCH fetch, STORE store, - const int64_t rows, const int64_t cols) { - DispatchSoftmaxWarpImplPackSize()(stream, fetch, store, rows, cols); +template +inline void DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols) { + DispatchSoftmaxWarpImplPackSize()(stream, load, store, rows, cols); } -template -__global__ void SoftmaxBlockSMemImpl(FETCH fetch, STORE store, const int64_t rows, +template +__global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) { - using ComputeType = typename GetComputeType::type; extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; auto* buf = reinterpret_cast(shared_buf); const int tid = threadIdx.x; @@ -376,7 +414,7 @@ __global__ void SoftmaxBlockSMemImpl(FETCH fetch, STORE store, const int64_t row ComputeType thread_max = -Inf(); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; - fetch.template fetch(pack, row, pack_id * pack_size); + load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { buf[i * num_packs + pack_id] = pack[i]; @@ -398,81 +436,93 @@ __global__ void SoftmaxBlockSMemImpl(FETCH fetch, STORE store, const int64_t row pack[i] = Div(buf[i * num_packs + pack_id], row_sum); thread_max = max(thread_max, pack[i]); } - store.template store(pack, row, pack_id * pack_size); + store.template store(pack, row, pack_id * pack_size); } } } -template -inline void LaunchSoftmaxBlockSMemImpl(cudaStream_t stream, FETCH fetch, STORE store, int smem, +template +inline void LaunchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, int smem, const int64_t rows, const int64_t cols) { constexpr int waves = 32; const int grid_dim_x = GetNumBlocks(block_size, rows, waves); - SoftmaxBlockSMemImpl - <<>>(fetch, store, rows, cols); + SoftmaxBlockSMemImpl + <<>>(load, store, rows, cols); } -template -inline bool TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, FETCH fetch, STORE store, +template +inline bool TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; - const size_t smem = cols * sizeof(typename GetComputeType::type); + constexpr int block_size_conf_3 = 512; + constexpr int block_size_conf_4 = 1024; + const size_t smem = cols * sizeof(ComputeType); int max_active_blocks_conf_1; - int max_active_blocks_conf_2; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_1, - SoftmaxBlockSMemImpl, block_size_conf_1, - smem)); + SoftmaxBlockSMemImpl, + block_size_conf_1, smem)); if (max_active_blocks_conf_1 <= 0) { return false; } + int max_active_blocks_conf_4; + OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_4, + SoftmaxBlockSMemImpl, + block_size_conf_4, smem)); + if (max_active_blocks_conf_4 == max_active_blocks_conf_1) { + LaunchSoftmaxBlockSMemImpl( + stream, load, store, smem, rows, cols); + return true; + } + int max_active_blocks_conf_3; + OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_3, + 
SoftmaxBlockSMemImpl, + block_size_conf_3, smem)); + if (max_active_blocks_conf_3 == max_active_blocks_conf_1) { + LaunchSoftmaxBlockSMemImpl( + stream, load, store, smem, rows, cols); + return true; + } + int max_active_blocks_conf_2; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_2, - SoftmaxBlockSMemImpl, block_size_conf_2, - smem)); + SoftmaxBlockSMemImpl, + block_size_conf_2, smem)); if (max_active_blocks_conf_2 == max_active_blocks_conf_1) { - LaunchSoftmaxBlockSMemImpl(stream, fetch, store, - smem, rows, cols); - } else { - LaunchSoftmaxBlockSMemImpl(stream, fetch, store, - smem, rows, cols); + LaunchSoftmaxBlockSMemImpl( + stream, load, store, smem, rows, cols); + return true; } + LaunchSoftmaxBlockSMemImpl( + stream, load, store, smem, rows, cols); return true; } -template +template struct TryDispatchSoftmaxBlockSMemImplPackSize { - bool operator()(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, - const int64_t cols) { - return TryDispatchSoftmaxBlockSMemImplBlockSize(stream, fetch, store, rows, - cols); - } -}; - -template -struct TryDispatchSoftmaxBlockSMemImplPackSize { - bool operator()(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, + bool operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { - return TryDispatchSoftmaxBlockSMemImplBlockSize(stream, fetch, store, - rows, cols); + return TryDispatchSoftmaxBlockSMemImplBlockSize( + stream, load, store, rows, cols); } else { - return TryDispatchSoftmaxBlockSMemImplBlockSize(stream, fetch, store, - rows, cols); + return TryDispatchSoftmaxBlockSMemImplBlockSize( + stream, load, store, rows, cols); } } }; -template -inline bool TryDispatchSoftmaxBlockSMemImpl(cudaStream_t stream, FETCH fetch, STORE store, +template +inline bool TryDispatchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { - return TryDispatchSoftmaxBlockSMemImplPackSize()(stream, fetch, store, rows, - cols); + return TryDispatchSoftmaxBlockSMemImplPackSize()(stream, load, store, + rows, cols); } -template -__global__ void SoftmaxBlockUncachedImpl(FETCH fetch, STORE store, const int64_t rows, +template +__global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) { - using ComputeType = typename GetComputeType::type; const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = cols / pack_size; @@ -480,7 +530,7 @@ __global__ void SoftmaxBlockUncachedImpl(FETCH fetch, STORE store, const int64_t ComputeType thread_max = -Inf(); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; - fetch.template fetch(pack, row, pack_id * pack_size); + load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { thread_max = max(thread_max, pack[i]); } } @@ -488,151 +538,184 @@ __global__ void SoftmaxBlockUncachedImpl(FETCH fetch, STORE store, const int64_t ComputeType thread_sum = 0; for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; - fetch.template fetch(pack, row, pack_id * pack_size); + load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { thread_sum += Exp(pack[i] - row_max); } } const ComputeType row_sum = BlockAllReduce(thread_sum); for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType pack[pack_size]; - 
fetch.template fetch(pack, row, pack_id * pack_size); + load.template load(pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { pack[i] = Div(Exp(pack[i] - row_max), row_sum); } - store.template store(pack, row, pack_id * pack_size); + store.template store(pack, row, pack_id * pack_size); } } } -template -inline void LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, FETCH fetch, STORE store, +template +inline void LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 1024; constexpr int waves = 32; const int grid_dim_x = GetNumBlocks(block_size, rows, waves); - SoftmaxBlockUncachedImpl - <<>>(fetch, store, rows, cols); + SoftmaxBlockUncachedImpl + <<>>(load, store, rows, cols); } -template +template struct DispatchSoftmaxBlockUncachedImplPackSize { - void operator()(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, - const int64_t cols) { - LaunchSoftmaxBlockUncachedImpl(stream, fetch, store, rows, cols); - } -}; - -template -struct DispatchSoftmaxBlockUncachedImplPackSize { - void operator()(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, + void operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { - LaunchSoftmaxBlockUncachedImpl(stream, fetch, store, rows, cols); + LaunchSoftmaxBlockUncachedImpl(stream, load, store, rows, cols); } else { - LaunchSoftmaxBlockUncachedImpl(stream, fetch, store, rows, cols); + LaunchSoftmaxBlockUncachedImpl(stream, load, store, rows, cols); } } }; -template -inline void DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, FETCH fetch, STORE store, +template +inline void DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { - return DispatchSoftmaxBlockUncachedImplPackSize()(stream, fetch, store, rows, - cols); + return DispatchSoftmaxBlockUncachedImplPackSize()(stream, load, store, + rows, cols); } -template -inline void DispatchSoftmax(cudaStream_t stream, FETCH fetch, STORE store, const int64_t rows, +template +inline void DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 1024) { - DispatchSoftmaxWarpImpl(stream, fetch, store, rows, cols); - } else if (!TryDispatchSoftmaxBlockSMemImpl(stream, fetch, store, rows, cols)) { - DispatchSoftmaxBlockUncachedImpl(stream, fetch, store, rows, cols); + DispatchSoftmaxWarpImpl(stream, load, store, rows, cols); + } else if (!TryDispatchSoftmaxBlockSMemImpl(stream, load, store, rows, + cols)) { + DispatchSoftmaxBlockUncachedImpl(stream, load, store, rows, cols); } } -template -__global__ void SoftmaxGradWarpImpl(FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, - const int64_t rows, const int64_t cols) { +template +__global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, + const int64_t cols) { static_assert(cols_per_thread % pack_size == 0, ""); constexpr int pack_per_thread = cols_per_thread / pack_size; - assert(cols <= cols_per_thread * kWarpSize); - using ComputeType = typename GetComputeType::type; - ComputeType y_buf[cols_per_thread]; - ComputeType dy_buf[cols_per_thread]; - const int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y; - const int num_global_warp = gridDim.x * blockDim.y; + assert(cols <= cols_per_thread * thread_group_width); + static_assert(thread_group_width <= kWarpSize, ""); + 
static_assert(kWarpSize % thread_group_width == 0, ""); + ComputeType y_buf[rows_per_access][cols_per_thread]; + ComputeType dy_buf[rows_per_access][cols_per_thread]; + const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y; + const int num_global_thread_group = gridDim.x * blockDim.y; const int lane_id = threadIdx.x; - for (int64_t row = global_warp_id; row < rows; row += num_global_warp) { - ComputeType thread_sum = 0; + for (int64_t row = global_thread_group_id * rows_per_access; row < rows; + row += num_global_thread_group * rows_per_access) { + ComputeType thread_sum[rows_per_access]; #pragma unroll - for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) { - const int col = (pack_id * kWarpSize + lane_id) * pack_size; - if (!padding || col < cols) { - fetch_y.template fetch(y_buf + pack_id * pack_size, row, col); - fetch_dy.template fetch(dy_buf + pack_id * pack_size, row, col); + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + thread_sum[row_id] = 0; + ComputeType* row_y_buf = y_buf[row_id]; + ComputeType* row_dy_buf = dy_buf[row_id]; #pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum += y_buf[pack_id * pack_size + i] * dy_buf[pack_id * pack_size + i]; + for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) { + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + if (!padding || col < cols) { + load_y.template load(row_y_buf + pack_id * pack_size, row + row_id, col); + load_dy.template load(row_dy_buf + pack_id * pack_size, row + row_id, col); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[row_id] += + row_y_buf[pack_id * pack_size + i] * row_dy_buf[pack_id * pack_size + i]; + } } } } - const ComputeType warp_sum = WarpAllReduce(thread_sum); + ComputeType warp_sum[rows_per_access]; +#pragma unroll + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + warp_sum[row_id] = WarpAllReduce(thread_sum[row_id]); + } #pragma unroll - for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) { - const int col = (pack_id * kWarpSize + lane_id) * pack_size; - if (!padding || col < cols) { - for (int i = 0; i < pack_size; ++i) { - dy_buf[pack_id * pack_size + i] = - (dy_buf[pack_id * pack_size + i] - warp_sum) * y_buf[pack_id * pack_size + i]; + for (int row_id = 0; row_id < rows_per_access; ++row_id) { + ComputeType* row_y_buf = y_buf[row_id]; + ComputeType* row_dy_buf = dy_buf[row_id]; +#pragma unroll + for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) { + const int col = (pack_id * thread_group_width + lane_id) * pack_size; + if (!padding || col < cols) { + for (int i = 0; i < pack_size; ++i) { + row_dy_buf[pack_id * pack_size + i] = + (row_dy_buf[pack_id * pack_size + i] - warp_sum[row_id]) + * row_y_buf[pack_id * pack_size + i]; + } + store.template store(row_dy_buf + pack_id * pack_size, row + row_id, col); } - store.template store(dy_buf + pack_id * pack_size, row, col); } } } } -template -inline void LaunchSoftmaxGradWarpImpl(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, +template +inline void LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 128; constexpr int waves = 32; - static_assert(block_size % kWarpSize == 0, ""); - constexpr int rows_per_block = block_size / kWarpSize; - dim3 block_dim(kWarpSize, rows_per_block); + static_assert(block_size % thread_group_width == 0, ""); + constexpr int rows_per_block = block_size / thread_group_width; + 
dim3 block_dim(thread_group_width, rows_per_block); const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; const int grid_dim_x = GetNumBlocks(block_size, num_blocks, waves); - SoftmaxGradWarpImpl - <<>>(fetch_y, fetch_dy, store, rows, cols); + SoftmaxGradWarpImpl + <<>>(load_y, load_dy, store, rows, cols); } -template -inline void DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, FETCH_Y fetch_y, - FETCH_DY fetch_dy, STORE store, const int64_t rows, +template +inline void DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, + STORE store, const int64_t rows, const int64_t cols) { - if (cols == cols_per_thread * kWarpSize) { - LaunchSoftmaxGradWarpImpl( - stream, fetch_y, fetch_dy, store, rows, cols); + if (cols == cols_per_thread * thread_group_width) { + LaunchSoftmaxGradWarpImpl(stream, load_y, load_dy, + store, rows, cols); } else { - LaunchSoftmaxGradWarpImpl( - stream, fetch_y, fetch_dy, store, rows, cols); + LaunchSoftmaxGradWarpImpl(stream, load_y, load_dy, + store, rows, cols); } } -template +template typename std::enable_if::type DispatchSoftmaxGradWarpImplCols( - cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, const int64_t rows, + cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { UNIMPLEMENTED(); } -#define DEFINE_ONE_ELIF(col) \ - else if (cols <= (col)*kWarpSize) { \ - DispatchSoftmaxGradWarpImplPadding( \ - stream, fetch_y, fetch_dy, store, rows, cols); \ +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + DispatchSoftmaxGradWarpImplPadding( \ + stream, load_y, load_dy, store, rows, cols); \ + } else { \ + DispatchSoftmaxGradWarpImplPadding( \ + stream, load_y, load_dy, store, rows, cols); \ + } \ } DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchSoftmaxGradWarpImplPadding(stream, load_y, load_dy, store, rows, cols); \ + } + DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(3) DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(5) @@ -669,18 +752,36 @@ typename std::enable_if::type DispatchSoftmaxGradWarpImplC } } -template +template typename std::enable_if::type DispatchSoftmaxGradWarpImplCols( - cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, const int64_t rows, + cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols <= 0) { UNIMPLEMENTED(); } -#define DEFINE_ONE_ELIF(col) \ - else if (cols <= (col)*kWarpSize) { \ - DispatchSoftmaxGradWarpImplPadding( \ - stream, fetch_y, fetch_dy, store, rows, cols); \ +#define DEFINE_ONE_ELIF(thread_group_width) \ + else if (cols <= (thread_group_width)*pack_size) { \ + if (rows % 2 == 0) { \ + DispatchSoftmaxGradWarpImplPadding( \ + stream, load_y, load_dy, store, rows, cols); \ + } else { \ + DispatchSoftmaxGradWarpImplPadding( \ + stream, load_y, load_dy, store, rows, cols); \ + } \ } + DEFINE_ONE_ELIF(1) DEFINE_ONE_ELIF(2) DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchSoftmaxGradWarpImplPadding(stream, load_y, load_dy, store, rows, cols); \ + } + DEFINE_ONE_ELIF(4) DEFINE_ONE_ELIF(6) DEFINE_ONE_ELIF(8) DEFINE_ONE_ELIF(10) @@ -701,41 
+802,31 @@ typename std::enable_if::type DispatchSoftmaxGradWarpImplC } } -template +template struct DispatchSoftmaxGradWarpImplPackSize { - void operator()(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, - const int64_t rows, const int64_t cols) { - DispatchSoftmaxGradWarpImplCols(stream, fetch_y, fetch_dy, - store, rows, cols); - } -}; - -template -struct DispatchSoftmaxGradWarpImplPackSize { - void operator()(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, + void operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { - if (cols % 2 == 0 && cols > kWarpSize) { - DispatchSoftmaxGradWarpImplCols(stream, fetch_y, fetch_dy, - store, rows, cols); + if (cols % 2 == 0) { + DispatchSoftmaxGradWarpImplCols( + stream, load_y, load_dy, store, rows, cols); } else { - DispatchSoftmaxGradWarpImplCols(stream, fetch_y, fetch_dy, - store, rows, cols); + DispatchSoftmaxGradWarpImplCols( + stream, load_y, load_dy, store, rows, cols); } } }; -template -inline void DispatchSoftmaxGradWarpImpl(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, +template +inline void DispatchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { - DispatchSoftmaxGradWarpImplPackSize()(stream, fetch_y, fetch_dy, - store, rows, cols); + DispatchSoftmaxGradWarpImplPackSize()( + stream, load_y, load_dy, store, rows, cols); } -template -__global__ void SoftmaxGradBlockSMemImpl(FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, +__global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { - using ComputeType = typename GetComputeType::type; extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[]; auto* y_buf = reinterpret_cast(grad_shared_buf); auto* dy_buf = y_buf + cols; @@ -747,8 +838,8 @@ __global__ void SoftmaxGradBlockSMemImpl(FETCH_Y fetch_y, FETCH_DY fetch_dy, STO for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType y_pack[pack_size]; ComputeType dy_pack[pack_size]; - fetch_y.template fetch(y_pack, row, pack_id * pack_size); - fetch_dy.template fetch(dy_pack, row, pack_id * pack_size); + load_y.template load(y_pack, row, pack_id * pack_size); + load_dy.template load(dy_pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { y_buf[i * num_packs + pack_id] = y_pack[i]; @@ -763,86 +854,101 @@ __global__ void SoftmaxGradBlockSMemImpl(FETCH_Y fetch_y, FETCH_DY fetch_dy, STO for (int i = 0; i < pack_size; ++i) { pack[i] = (dy_buf[i * num_packs + pack_id] - row_sum) * y_buf[i * num_packs + pack_id]; } - store.template store(pack, row, pack_id * pack_size); + store.template store(pack, row, pack_id * pack_size); } } } -template -inline void LaunchSoftmaxGradBlockSMemImpl(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, +inline void LaunchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, int smem, const int64_t rows, const int64_t cols) { constexpr int waves = 32; const int grid_dim_x = GetNumBlocks(block_size, rows, waves); - SoftmaxGradBlockSMemImpl - <<>>(fetch_y, fetch_dy, store, rows, cols); + SoftmaxGradBlockSMemImpl + <<>>(load_y, load_dy, store, rows, cols); } -template -inline bool TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, FETCH_Y fetch_y, - FETCH_DY fetch_dy, STORE store, +template +inline bool 
TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, LOAD_Y load_y, + LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size_conf_1 = 128; constexpr int block_size_conf_2 = 256; - const size_t smem = cols * sizeof(typename GetComputeType::type) * 2; + constexpr int block_size_conf_3 = 512; + constexpr int block_size_conf_4 = 1024; + const size_t smem = cols * sizeof(ComputeType) * 2; int max_active_blocks_conf_1; - int max_active_blocks_conf_2; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_1, - SoftmaxGradBlockSMemImpl, + SoftmaxGradBlockSMemImpl, block_size_conf_1, smem)); if (max_active_blocks_conf_1 <= 0) { return false; } + int max_active_blocks_conf_4; + OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_4, + SoftmaxGradBlockSMemImpl, + block_size_conf_4, smem)); + if (max_active_blocks_conf_4 == max_active_blocks_conf_1) { + LaunchSoftmaxGradBlockSMemImpl(stream, load_y, load_dy, store, smem, rows, + cols); + return true; + } + int max_active_blocks_conf_3; + OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_conf_3, + SoftmaxGradBlockSMemImpl, + block_size_conf_3, smem)); + if (max_active_blocks_conf_3 == max_active_blocks_conf_1) { + LaunchSoftmaxGradBlockSMemImpl(stream, load_y, load_dy, store, smem, rows, + cols); + return true; + } + int max_active_blocks_conf_2; OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_conf_2, - SoftmaxGradBlockSMemImpl, + SoftmaxGradBlockSMemImpl, block_size_conf_2, smem)); if (max_active_blocks_conf_2 == max_active_blocks_conf_1) { - LaunchSoftmaxGradBlockSMemImpl( - stream, fetch_y, fetch_dy, store, smem, rows, cols); - } else { - LaunchSoftmaxGradBlockSMemImpl( - stream, fetch_y, fetch_dy, store, smem, rows, cols); + LaunchSoftmaxGradBlockSMemImpl(stream, load_y, load_dy, store, smem, rows, + cols); + return true; } + LaunchSoftmaxGradBlockSMemImpl( + stream, load_y, load_dy, store, smem, rows, cols); return true; } -template +template struct TryDispatchSoftmaxGradBlockSMemImplPackSize { - bool operator()(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, - const int64_t rows, const int64_t cols) { - return TryDispatchSoftmaxGradBlockSMemImplBlockSize( - stream, fetch_y, fetch_dy, store, rows, cols); - } -}; - -template -struct TryDispatchSoftmaxGradBlockSMemImplPackSize { - bool operator()(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, + bool operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0) { - return TryDispatchSoftmaxGradBlockSMemImplBlockSize( - stream, fetch_y, fetch_dy, store, rows, cols); + return TryDispatchSoftmaxGradBlockSMemImplBlockSize( + stream, load_y, load_dy, store, rows, cols); } else { - return TryDispatchSoftmaxGradBlockSMemImplBlockSize( - stream, fetch_y, fetch_dy, store, rows, cols); + return TryDispatchSoftmaxGradBlockSMemImplBlockSize( + stream, load_y, load_dy, store, rows, cols); } } }; -template -inline bool TryDispatchSoftmaxGradBlockSMemImpl(cudaStream_t stream, FETCH_Y fetch_y, - FETCH_DY fetch_dy, STORE store, const int64_t rows, +template +inline bool TryDispatchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, + STORE store, const int64_t rows, const int64_t cols) { - return TryDispatchSoftmaxGradBlockSMemImplPackSize()( - stream, fetch_y, fetch_dy, store, rows, cols); + 
return TryDispatchSoftmaxGradBlockSMemImplPackSize()( + stream, load_y, load_dy, store, rows, cols); } -template -__global__ void SoftmaxGradBlockUncachedImpl(FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, +__global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { - using ComputeType = typename GetComputeType::type; const int tid = threadIdx.x; assert(cols % pack_size == 0); const int num_packs = cols / pack_size; @@ -851,8 +957,8 @@ __global__ void SoftmaxGradBlockUncachedImpl(FETCH_Y fetch_y, FETCH_DY fetch_dy, for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType y_pack[pack_size]; ComputeType dy_pack[pack_size]; - fetch_y.template fetch(y_pack, row, pack_id * pack_size); - fetch_dy.template fetch(dy_pack, row, pack_id * pack_size); + load_y.template load(y_pack, row, pack_id * pack_size); + load_dy.template load(dy_pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { thread_sum += y_pack[i] * dy_pack[i]; } @@ -861,67 +967,58 @@ __global__ void SoftmaxGradBlockUncachedImpl(FETCH_Y fetch_y, FETCH_DY fetch_dy, for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) { ComputeType y_pack[pack_size]; ComputeType dy_pack[pack_size]; - fetch_y.template fetch(y_pack, row, pack_id * pack_size); - fetch_dy.template fetch(dy_pack, row, pack_id * pack_size); + load_y.template load(y_pack, row, pack_id * pack_size); + load_dy.template load(dy_pack, row, pack_id * pack_size); #pragma unroll for (int i = 0; i < pack_size; ++i) { dy_pack[i] = (dy_pack[i] - row_sum) * y_pack[i]; } - store.template store(dy_pack, row, pack_id * pack_size); + store.template store(dy_pack, row, pack_id * pack_size); } } } -template -inline void LaunchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, FETCH_Y fetch_y, - FETCH_DY fetch_dy, STORE store, const int64_t rows, +template +inline void LaunchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, + STORE store, const int64_t rows, const int64_t cols) { constexpr int block_size = 1024; constexpr int waves = 32; const int grid_dim_x = GetNumBlocks(block_size, rows, waves); - SoftmaxGradBlockUncachedImpl - <<>>(fetch_y, fetch_dy, store, rows, cols); + SoftmaxGradBlockUncachedImpl + <<>>(load_y, load_dy, store, rows, cols); } -template +template struct DispatchSoftmaxGradBlockUncachedImplPackSize { - void operator()(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, - const int64_t rows, const int64_t cols) { - LaunchSoftmaxGradBlockUncachedImpl(stream, fetch_y, fetch_dy, - store, rows, cols); - } -}; - -template -struct DispatchSoftmaxGradBlockUncachedImplPackSize { - void operator()(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy, STORE store, + void operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows, const int64_t cols) { if (cols % 2 == 0 && cols > kWarpSize) { - LaunchSoftmaxGradBlockUncachedImpl( - stream, fetch_y, fetch_dy, store, rows, cols); + LaunchSoftmaxGradBlockUncachedImpl( + stream, load_y, load_dy, store, rows, cols); } else { - LaunchSoftmaxGradBlockUncachedImpl( - stream, fetch_y, fetch_dy, store, rows, cols); + LaunchSoftmaxGradBlockUncachedImpl( + stream, load_y, load_dy, store, rows, cols); } } }; -template -inline void DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, FETCH_Y fetch_y, - FETCH_DY fetch_dy, STORE store, const int64_t rows, +template +inline void 
DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,
+                                                 LOAD_DY load_dy, STORE store, const int64_t rows,
                                                  const int64_t cols) {
-  return DispatchSoftmaxGradBlockUncachedImplPackSize<FETCH_Y, FETCH_DY, STORE, T>()(
-      stream, fetch_y, fetch_dy, store, rows, cols);
+  return DispatchSoftmaxGradBlockUncachedImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType>()(
+      stream, load_y, load_dy, store, rows, cols);
 }
 
-template<typename FETCH_Y, typename FETCH_DY, typename STORE, typename T>
-inline void DispatchSoftmaxGrad(cudaStream_t stream, FETCH_Y fetch_y, FETCH_DY fetch_dy,
-                                STORE store, const int64_t rows, const int64_t cols) {
+template<typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
+inline void DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
+                                const int64_t rows, const int64_t cols) {
   if (cols <= 1024) {
-    DispatchSoftmaxGradWarpImpl<FETCH_Y, FETCH_DY, STORE, T>(stream, fetch_y, fetch_dy, store, rows,
-                                                             cols);
-  } else if (!TryDispatchSoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T>(
-                 stream, fetch_y, fetch_dy, store, rows, cols)) {
-    DispatchSoftmaxGradBlockUncachedImpl<FETCH_Y, FETCH_DY, STORE, T>(stream, fetch_y, fetch_dy,
-                                                                      store, rows, cols);
+    DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType>(stream, load_y, load_dy, store,
+                                                                     rows, cols);
+  } else if (!TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType>(
+                 stream, load_y, load_dy, store, rows, cols)) {
+    DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType>(
+        stream, load_y, load_dy, store, rows, cols);
   }
 }
diff --git a/oneflow/python/test/ops/test_softmax.py b/oneflow/python/test/ops/test_softmax.py
index 1e33035e639..798327dcf54 100644
--- a/oneflow/python/test/ops/test_softmax.py
+++ b/oneflow/python/test/ops/test_softmax.py
@@ -105,6 +105,7 @@ def test_softmax_shape(test_case):
         (10, 20, 30),
         (10, 20),
         (10, 60),
+        (15, 60),
         (32, 12, 128),
         (10, 960),
         (12, 2001),
diff --git a/oneflow/user/kernels/fused_tril_scale_softmax_mask_scale_kernel.cu b/oneflow/user/kernels/fused_tril_scale_softmax_mask_scale_kernel.cu
index 0a1f4617f51..ff69112c028 100644
--- a/oneflow/user/kernels/fused_tril_scale_softmax_mask_scale_kernel.cu
+++ b/oneflow/user/kernels/fused_tril_scale_softmax_mask_scale_kernel.cu
@@ -18,23 +18,23 @@ limitations under the License.
 namespace oneflow {
 
-template<typename SRC>
-struct TrilScaleFetch {
-  TrilScaleFetch(const SRC* src, int64_t tril_num_rows, int64_t row_size, int64_t diagonal,
-                 SRC fill, SRC scale)
+template<typename SRC, typename DST>
+struct TrilScaleLoad {
+  TrilScaleLoad(const SRC* src, int64_t tril_num_rows, int64_t row_size, int64_t diagonal, SRC fill,
+                SRC scale)
       : src(src),
         tril_num_rows(tril_num_rows),
         row_size(row_size),
         diagonal(diagonal),
         fill(fill),
         scale(scale) {}
-  template<typename DST, int N>
-  __device__ void fetch(DST* dst, int64_t row, int64_t col) {
+  template<int N>
+  __device__ void load(DST* dst, int64_t row, int64_t col) {
     int64_t tril_row = row % tril_num_rows;
     int64_t diagonal_col_id = tril_row + diagonal;
-    bool need_fetch = (col <= diagonal_col_id);
+    bool need_load = (col <= diagonal_col_id);
     cuda::softmax::Pack<SRC, N> pack;
-    if (need_fetch) {
+    if (need_load) {
       const int64_t offset = row * row_size + col;
       pack.storage = *reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src + offset);
     }
@@ -55,11 +55,11 @@ struct TrilScaleFetch {
   SRC scale;
 };
 
-template<typename DST>
+template<typename SRC, typename DST>
 struct MaskAndScaleStore {
   MaskAndScaleStore(DST* dst, DST* softmax_y, const int8_t* mask, int64_t row_size, DST scale)
       : dst(dst), softmax_y(softmax_y), mask(mask), row_size(row_size), scale(scale) {}
-  template<typename SRC, int N>
+  template<int N>
   __device__ void store(const SRC* src, int64_t row, int64_t col) {
     cuda::softmax::Pack<DST, N> softmax_y_pack;
     cuda::softmax::Pack<DST, N> dst_pack;
@@ -83,12 +83,12 @@ struct MaskAndScaleStore {
   DST scale;
 };
 
-template<typename SRC>
-struct MaskAndScaleFetch {
-  MaskAndScaleFetch(const SRC* src, const int8_t* mask, int64_t row_size, SRC scale)
+template<typename SRC, typename DST>
+struct MaskAndScaleLoad {
+  MaskAndScaleLoad(const SRC* src, const int8_t* mask, int64_t row_size, SRC scale)
       : src(src), mask(mask), row_size(row_size), scale(scale) {}
-  template<typename DST, int N>
-  __device__ void fetch(DST* dst, int64_t row, int64_t col) const {
+  template<int N>
+  __device__ void load(DST* dst, int64_t row, int64_t col) const {
     cuda::softmax::Pack<SRC, N> pack;
     const int64_t offset = row * row_size + col;
     pack.storage = *reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src + offset);
@@ -106,7 +106,7 @@ struct MaskAndScaleFetch {
   SRC scale;
 };
 
-template<typename DST>
+template<typename SRC, typename DST>
 struct TrilScaleStore {
   TrilScaleStore(DST* dst, int64_t tril_num_rows, int64_t row_size, int64_t diagonal, DST fill,
                  DST scale)
@@ -116,7 +116,7 @@ struct TrilScaleStore {
         diagonal(diagonal),
         fill(fill),
         scale(scale) {}
-  template<typename SRC, int N>
+  template<int N>
   __device__ void store(const SRC* src, int64_t row, int64_t col) {
     cuda::softmax::Pack<DST, N> pack;
     const int64_t offset = row * row_size + col;
@@ -156,13 +156,15 @@ class FusedTrilScaleSoftmaxMaskScaleKernel final : public user_op::OpKernel {
     const int64_t cols = x_shape.At(x_shape.NumAxes() - 1);
     const int64_t rows = x_shape.Count(0, x_shape.NumAxes() - 1);
     const int64_t tril_num_rows = x_shape.At(x_shape.NumAxes() - 2);
-    TrilScaleFetch<T> fetch(x->dptr<T>(), tril_num_rows, cols, ctx->Attr<int64_t>("diagonal"),
-                            ctx->Attr<float>("tril_fill_value"),
-                            ctx->Attr<float>("tril_scale_value"));
-    MaskAndScaleStore<T> store(y->mut_dptr<T>(), softmax_y->mut_dptr<T>(), mask->dptr<int8_t>(),
-                               cols, ctx->Attr<float>("mask_scale_value"));
-    cuda::softmax::DispatchSoftmax<decltype(fetch), decltype(store), T>(
-        ctx->device_ctx()->cuda_stream(), fetch, store, rows, cols);
+    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
+    TrilScaleLoad<T, ComputeType> load(
+        x->dptr<T>(), tril_num_rows, cols, ctx->Attr<int64_t>("diagonal"),
+        ctx->Attr<float>("tril_fill_value"), ctx->Attr<float>("tril_scale_value"));
+    MaskAndScaleStore<ComputeType, T> store(y->mut_dptr<T>(), softmax_y->mut_dptr<T>(),
+                                            mask->dptr<int8_t>(), cols,
+                                            ctx->Attr<float>("mask_scale_value"));
+    cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
+        ctx->device_ctx()->cuda_stream(), load, store, rows, cols);
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
@@ -195,14 +197,16 @@ class FusedTrilScaleSoftmaxMaskScaleGradKernel final : public user_op::OpKernel
     const int64_t cols = dy_shape.At(dy_shape.NumAxes() - 1);
     const int64_t rows = dy_shape.Count(0, dy_shape.NumAxes() - 1);
     const int64_t tril_num_rows = dy_shape.At(dy_shape.NumAxes() - 2);
-    cuda::softmax::DirectFetch<T> fetch_softmax_y(softmax_y->dptr<T>(), cols);
-    MaskAndScaleFetch<T> fetch_dy(dy->dptr<T>(), mask->dptr<int8_t>(), cols,
-                                  ctx->Attr<float>("mask_scale_value"));
-    TrilScaleStore<T> store(dx->mut_dptr<T>(), tril_num_rows, cols, ctx->Attr<int64_t>("diagonal"),
-                            static_cast<T>(0.0), ctx->Attr<float>("tril_scale_value"));
-    cuda::softmax::DispatchSoftmaxGrad<decltype(fetch_softmax_y), decltype(fetch_dy),
-                                       decltype(store), T>(
-        ctx->device_ctx()->cuda_stream(), fetch_softmax_y, fetch_dy, store, rows, cols);
+    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
+    cuda::softmax::DirectLoad<T, ComputeType> load_softmax_y(softmax_y->dptr<T>(), cols);
+    MaskAndScaleLoad<T, ComputeType> load_dy(dy->dptr<T>(), mask->dptr<int8_t>(), cols,
+                                             ctx->Attr<float>("mask_scale_value"));
+    TrilScaleStore<ComputeType, T> store(dx->mut_dptr<T>(), tril_num_rows, cols,
+                                         ctx->Attr<int64_t>("diagonal"), static_cast<T>(0.0),
+                                         ctx->Attr<float>("tril_scale_value"));
+    cuda::softmax::DispatchSoftmaxGrad<decltype(load_softmax_y), decltype(load_dy), decltype(store),
+                                       ComputeType>(ctx->device_ctx()->cuda_stream(),
+                                                    load_softmax_y, load_dy, store, rows, cols);
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
diff --git a/oneflow/user/kernels/softmax_kernel.cu b/oneflow/user/kernels/softmax_kernel.cu
index 86071664655..838af541831 100644
--- a/oneflow/user/kernels/softmax_kernel.cu
+++ b/oneflow/user/kernels/softmax_kernel.cu
@@ -31,10 +31,11 @@ class SoftmaxKernel final : public user_op::OpKernel {
     const ShapeView& in_shape = in->shape();
     const int64_t cols = in_shape.At(in_shape.NumAxes() - 1);
     const int64_t rows = in_shape.Count(0, in_shape.NumAxes() - 1);
-    cuda::softmax::DirectFetch<T> fetch(in->dptr<T>(), cols);
-    cuda::softmax::DirectStore<T> store(out->mut_dptr<T>(), cols);
-    cuda::softmax::DispatchSoftmax<decltype(fetch), decltype(store), T>(
-        ctx->device_ctx()->cuda_stream(), fetch, store, rows, cols);
+    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
+    cuda::softmax::DirectLoad<T, ComputeType> load(in->dptr<T>(), cols);
+    cuda::softmax::DirectStore<ComputeType, T> store(out->mut_dptr<T>(), cols);
+    cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
+        ctx->device_ctx()->cuda_stream(), load, store, rows, cols);
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
@@ -62,11 +63,13 @@ class SoftmaxGradKernel final : public user_op::OpKernel {
     user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
     const int64_t cols = y->shape().At(y->shape().NumAxes() - 1);
     const int64_t rows = y->shape().elem_cnt() / cols;
-    cuda::softmax::DirectFetch<T> fetch_y(y->dptr<T>(), cols);
-    cuda::softmax::DirectFetch<T> fetch_dy(dy->dptr<T>(), cols);
-    cuda::softmax::DirectStore<T> store(dx->mut_dptr<T>(), cols);
-    cuda::softmax::DispatchSoftmaxGrad<decltype(fetch_y), decltype(fetch_dy), decltype(store), T>(
-        ctx->device_ctx()->cuda_stream(), fetch_y, fetch_dy, store, rows, cols);
+    using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
+    cuda::softmax::DirectLoad<T, ComputeType> load_y(y->dptr<T>(), cols);
+    cuda::softmax::DirectLoad<T, ComputeType> load_dy(dy->dptr<T>(), cols);
+    cuda::softmax::DirectStore<ComputeType, T> store(dx->mut_dptr<T>(), cols);
+    cuda::softmax::DispatchSoftmaxGrad<decltype(load_y), decltype(load_dy), decltype(store),
+                                       ComputeType>(ctx->device_ctx()->cuda_stream(), load_y,
+                                                    load_dy, store, rows, cols);
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
diff --git a/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cu b/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cu
index 2730c3644c8..7bdb3cde080 100644
--- a/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cu
+++ b/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cu
@@ -24,18 +24,19 @@ namespace {
 
 template<typename T>
 void ComputeProb(DeviceCtx* ctx, const int64_t row, const int64_t col, const T* in, T* prob) {
-  cuda::softmax::DirectFetch<T> fetch(in, col);
-  cuda::softmax::DirectStore<T> store(prob, col);
-  cuda::softmax::DispatchSoftmax<decltype(fetch), decltype(store), T>(ctx->cuda_stream(), fetch,
-                                                                      store, row, col);
+  using ComputeType = typename cuda::softmax::DefaultComputeType<T>::type;
+  cuda::softmax::DirectLoad<T, ComputeType> load(in, col);
+  cuda::softmax::DirectStore<ComputeType, T> store(prob, col);
+  cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
+      ctx->cuda_stream(), load, store, row, col);
 }
 
 template<>
 void ComputeProb(DeviceCtx* ctx, const int64_t row, const int64_t col, const float16* in,
                  float16* prob) {
-  cuda::softmax::DirectFetch<half> fetch(reinterpret_cast<const half*>(in), col);
-  cuda::softmax::DirectStore<half> store(reinterpret_cast<half*>(prob), col);
-  cuda::softmax::DispatchSoftmax<decltype(fetch), decltype(store), half>(ctx->cuda_stream(), fetch,
+  cuda::softmax::DirectLoad<half, float> load(reinterpret_cast<const half*>(in), col);
+  cuda::softmax::DirectStore<float, half> store(reinterpret_cast<half*>(prob), col);
+  cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(ctx->cuda_stream(), load,
                                                                          store, row, col);
 }
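Two short sketches follow for context; neither is part of the diff. First, the point of splitting SRC/DST onto the Load/Store structs: any elementwise prologue can be fused into the softmax kernels by supplying a custom Load functor instead of materializing an intermediate tensor. A minimal sketch, assuming the Pack/PackType helpers and the DispatchSoftmax signature from softmax.cuh above; ScaleLoad and the caller-side names (x, y, stream) are hypothetical, not part of this PR:

// Hypothetical ScaleLoad: fuses `x * scale` into the softmax input path.
// Mirrors DirectLoad from this diff; SRC is the storage type, DST the compute type.
template<typename SRC, typename DST>
struct ScaleLoad {
  ScaleLoad(const SRC* src, SRC scale, int64_t row_size)
      : src(src), scale(scale), row_size(row_size) {}
  template<int N>
  __device__ void load(DST* dst, int64_t row, int64_t col) const {
    cuda::softmax::Pack<SRC, N> pack;
    const int64_t offset = row * row_size + col;
    // One vectorized read of N elements, then scale in the compute type.
    pack.storage = *reinterpret_cast<const cuda::softmax::PackType<SRC, N>*>(src + offset);
#pragma unroll
    for (int i = 0; i < N; ++i) {
      dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(scale);
    }
  }
  const SRC* src;
  SRC scale;
  int64_t row_size;
};

// Caller side, following the same pattern as softmax_kernel.cu above (float input):
// using ComputeType = typename cuda::softmax::DefaultComputeType<float>::type;
// ScaleLoad<float, ComputeType> load(x, 0.125f, cols);
// cuda::softmax::DirectStore<ComputeType, float> store(y, cols);
// cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), ComputeType>(
//     stream, load, store, rows, cols);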
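Second, the new GetPackType replaces the per-type half2/char2 specializations with std::aligned_storage, which generalizes to any trivially copyable element type: the storage blob has exactly the size and alignment of N packed elements, so one assignment of pack.storage compiles to a single vectorized memory transaction (N * sizeof(T) has to be a power of two, and at most 16 bytes for CUDA vector access). A host-side sanity check, as a sketch; the PackType alias is written out locally rather than taken from the header:

// Standalone check of the aligned_storage-based pack type introduced above.
#include <cstdint>
#include <type_traits>

template<typename T, int N>
using PackType = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;

// The blob must match N packed elements exactly, in both size and alignment.
static_assert(sizeof(PackType<float, 2>) == 8 && alignof(PackType<float, 2>) == 8, "");
static_assert(sizeof(PackType<int8_t, 2>) == 2 && alignof(PackType<int8_t, 2>) == 2, "");

int main() { return 0; }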