diff --git a/apex/contrib/csrc/groupbn/batch_norm.h b/apex/contrib/csrc/groupbn/batch_norm.h
index 90722043b..e52751bce 100644
--- a/apex/contrib/csrc/groupbn/batch_norm.h
+++ b/apex/contrib/csrc/groupbn/batch_norm.h
@@ -36,7 +36,7 @@
 #include "nhwc_batch_norm_kernel.h"
 #include "cuda_utils.h"
 #include "c10/macros/Macros.h"
-
+#include <ATen/cuda/CUDAContext.h>
 
 #define VERBOSE_DEFAULT false
 
@@ -626,7 +626,7 @@ class NhwcBatchNorm {
   // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
   static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
     using namespace at::cuda::utils;
-    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
     int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
     int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
     return std::min(max_cta_per_sm, occupancy);
@@ -635,7 +635,7 @@ class NhwcBatchNorm {
   // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
   static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
     using namespace at::cuda::utils;
-    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
     int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
     int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
     return std::min(max_cta_per_sm, occupancy);
diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h b/apex/contrib/csrc/groupbn/batch_norm_add_relu.h
index de9428ca7..0481a9408 100644
--- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h
+++ b/apex/contrib/csrc/groupbn/batch_norm_add_relu.h
@@ -36,6 +36,7 @@
 #include "nhwc_batch_norm_kernel.h"
 #include "cuda_utils.h"
 #include "c10/macros/Macros.h"
+#include <ATen/cuda/CUDAContext.h>
 
 #ifdef USE_ROCM
 using bitmask_t = uint64_t;
@@ -530,7 +531,7 @@ class NhwcBatchNormAddRelu {
   // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
   static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
     using namespace at::cuda::utils;
-    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+    int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
     int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
     int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
     return std::min(max_cta_per_sm, occupancy);
@@ -539,7 +540,7 @@ class NhwcBatchNormAddRelu {
   // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
   static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
     using namespace at::cuda::utils;
-    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+    int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
     int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
     int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
     return std::min(max_cta_per_sm, occupancy);
diff --git a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
index 0dd47a340..44ec92688 100644
--- a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
+++ b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
@@ -546,14 +546,8 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
                                         void* params_my_data, void** params_pair_datas, int off,
                                         const int magic, const int sync_iters) {
-  // The size of a warp.
-#ifdef USE_ROCM
-  const int THREADS_PER_WARP = 64;
-#else
-  const int THREADS_PER_WARP = 32;
-#endif
   // The number of warps in a CTA.
-  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
+  const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE;
   // The number of threads per pixel.
   const int THREADS_PER_PIXEL = 16;
   // The number of elements per ldg.
@@ -564,13 +558,13 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
   const int MAX_BLOCK_Y = 256;
   const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
   // The warp decomposition.
-  const int warp_id = threadIdx.x / THREADS_PER_WARP;
-  const int lane_id = threadIdx.x % THREADS_PER_WARP;
+  const int warp_id = threadIdx.x / C10_WARP_SIZE;
+  const int lane_id = threadIdx.x % C10_WARP_SIZE;
   // total size of data per sync iter
   const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;
 
 #ifdef USE_ROCM
-  for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
+  for (int offset = THREADS_PER_PIXEL; offset <= C10_WARP_SIZE >> 1; offset <<= 1) {
     for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
       x[i] += shfl_sync(x[i], offset + lane_id);
     }
@@ -598,16 +592,16 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
 
 #pragma unroll
   for (int offset = 1;
-      offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+      offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) {
     float y[ELEMENTS_PER_LDG];
     // Read the mean and variance from the other pixel.
-    read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+    read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE);
     // Compute the updated sum.
     add(x, y);
   }
 
 #ifdef USE_ROCM
-  for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+  for (int offset = C10_WARP_SIZE >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
       x[i] += shfl_sync(x[i], offset + lane_id);
     }
   }
@@ -681,21 +675,15 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
 
 template< int THREADS_PER_CTA >
 DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
-  // The size of a warp.
-#ifdef USE_ROCM
-  const int THREADS_PER_WARP = 64;
-#else
-  const int THREADS_PER_WARP = 32;
-#endif
   // The number of warps in a CTA.
-  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
+  const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE;
   // The number of threads per pixel.
   const int THREADS_PER_PIXEL = 8;
   // The number of elements per ldg.
   const int ELEMENTS_PER_LDG = 4;
   // The warp decomposition.
-  const int warp_id = threadIdx.x / THREADS_PER_WARP;
-  const int lane_id = threadIdx.x % THREADS_PER_WARP;
+  const int warp_id = threadIdx.x / C10_WARP_SIZE;
+  const int lane_id = threadIdx.x % C10_WARP_SIZE;
 
 #pragma unroll
   for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
@@ -718,10 +706,10 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
 
 #pragma unroll
   for (int offset = 1;
-      offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+      offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) {
     float y[ELEMENTS_PER_LDG];
     // Read the mean and variance from the other pixel.
-    read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+    read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE);
     // Compute the updated sum.
     add(x, y);
   }
@@ -745,20 +733,14 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
 
 template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
 DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
-  // The size of a warp.
-#ifdef USE_ROCM
-  const int THREADS_PER_WARP = 64;
-#else
-  const int THREADS_PER_WARP = 32;
-#endif
-  const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
+  const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE;
   // The warp decomposition.
-  const int warp_id = threadIdx.x / THREADS_PER_WARP;
-  const int lane_id = threadIdx.x % THREADS_PER_WARP;
+  const int warp_id = threadIdx.x / C10_WARP_SIZE;
+  const int lane_id = threadIdx.x % C10_WARP_SIZE;
   // total size of data per sync iter
 
 #ifdef USE_ROCM
-  for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
+  for (int offset = THREADS_PER_PIXEL; offset <= C10_WARP_SIZE >> 1; offset <<= 1) {
     for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
       x[i] += shfl_sync(x[i], offset + lane_id);
     }
@@ -786,16 +768,16 @@ DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], in
 
 #pragma unroll
   for (int offset = 1;
-      offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+      offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) {
     float y[ELEMENTS_PER_LDG];
     // Read the mean and variance from the other pixel.
-    read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+    read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE);
     // Compute the updated sum.
     add(x, y);
   }
 
 #ifdef USE_ROCM
-  for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
+  for (int offset = C10_WARP_SIZE >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
     for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
       x[i] += shfl_sync(x[i], offset + lane_id);
     }
diff --git a/apex/contrib/csrc/multihead_attn/softmax.cuh b/apex/contrib/csrc/multihead_attn/softmax.cuh
index d6fa55553..6e7da0f71 100644
--- a/apex/contrib/csrc/multihead_attn/softmax.cuh
+++ b/apex/contrib/csrc/multihead_attn/softmax.cuh
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <ATen/cuda/CUDAContext.h>
 
 #ifdef USE_ROCM
 #define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
@@ -235,7 +236,7 @@ bool warp_softmax_kernel(int log2_elements, int &warp_size,
                          softmax_forward_func &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -654,7 +655,7 @@ bool warp_additive_masked_softmax_dropout_kernel(
     &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -948,7 +949,7 @@ bool warp_additive_masked_softmax_kernel(
     additive_masked_softmax_forward_func &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -1240,7 +1241,7 @@ bool warp_masked_softmax_kernel(
     masked_softmax_forward_func &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -1488,7 +1489,7 @@ bool warp_time_masked_softmax_kernel(
     time_masked_softmax_forward_func &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -1741,7 +1742,7 @@ void dispatch_masked_scale_softmax_backward_masked_out(
   // This value must match the WARP_SIZE constexpr value computed inside
   // softmax_warp_backward.
   int warp_size =
-      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+      (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // This value must match the WARP_BATCH constexpr value computed inside
   // softmax_warp_backward.
@@ -1855,7 +1856,8 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream(
   // This value must match the WARP_SIZE constexpr value computed inside
   // softmax_warp_backward.
   int warp_size =
-      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+      (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
+
   // This value must match the WARP_BATCH constexpr value computed inside
   // softmax_warp_backward.
   int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -2254,7 +2256,7 @@ bool masked_scale_softmax_warp_backward_recompute_kernel(
     is_log_softmax> &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -2392,7 +2394,8 @@ void dispatch_masked_scale_softmax_backward_stream(
   // This value must match the WARP_SIZE constexpr value computed inside
   // softmax_warp_backward.
   int warp_size =
-      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+      (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
+
   // This value must match the WARP_BATCH constexpr value computed inside
   // softmax_warp_backward.
   int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -2593,7 +2596,7 @@ void dispatch_softmax_backward_fused_native(
   // This value must match the WARP_SIZE constexpr value computed inside
   // softmax_warp_backward.
   int warp_size =
-      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+      (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // This value must match the WARP_BATCH constexpr value computed inside
   // softmax_warp_backward.
@@ -2805,7 +2808,7 @@ bool warp_softmax_backward_kernel(
     softmax_backward_func &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -3048,7 +3051,7 @@ bool warp_masked_softmax_backward_kernel(
     masked_softmax_backward_func &kernel) {
   // determine size of a warp
   const int next_power_of_two = 1 << log2_elements;
-  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+  warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
   // determine how many batches a warp should process.
   batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu
index 477c1de58..05c64320b 100755
--- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu
+++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu
@@ -729,8 +729,8 @@ std::vector transducer_joint_cuda_forward(
     TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt);
     // Simple heuristics
-    const int numThread = std::min(128, (static_cast(hiddenSize)+C10_WARP_SIZE-1)
-                                            / C10_WARP_SIZE * C10_WARP_SIZE);
+    const int numThread = std::min(128, (static_cast(hiddenSize)+at::cuda::warp_size()-1)
+                                            / at::cuda::warp_size() * at::cuda::warp_size());
 
     if (opt == 0){
         // vanilla kernel
@@ -862,7 +862,7 @@ std::vector transducer_joint_cuda_backward(
    const int hiddenSize = grad.size(-1);
 
    const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
-    const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;
+    const int maxNumWarp = deviceProperties->maxThreadsPerBlock / at::cuda::warp_size();
 
    torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);
    torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);
@@ -880,8 +880,8 @@ std::vector transducer_joint_cuda_backward(
 
        // Need smem for transposing the partial sum. The partial sum is in a matrix of the shape
        // numWarp x warpSize
-        const int smemSize = numWarp * C10_WARP_SIZE;
-        const dim3 threads(C10_WARP_SIZE, numWarp, 1);
+        const int smemSize = numWarp * at::cuda::warp_size();
+        const dim3 threads(at::cuda::warp_size(), numWarp, 1);
 
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] {
            auto gradPtr = grad.data_ptr();
@@ -905,7 +905,7 @@ std::vector transducer_joint_cuda_backward(
            if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){
                // If vectorization helps and the alignment requirement is met, use the vectorized
                // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor.
-                const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor),
+                const dim3 blocks( (hiddenSize+at::cuda::warp_size()*vectFactor-1)/(at::cuda::warp_size()*vectFactor),
                                    maxFLen+maxGLen,
                                    batchSize);
                if (masked){
@@ -944,7 +944,7 @@ std::vector transducer_joint_cuda_backward(
                }
            }
            else{
-                const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE,
+                const dim3 blocks((hiddenSize+at::cuda::warp_size()-1)/at::cuda::warp_size(),
                                    maxFLen + maxGLen, batchSize);
 
                if (masked){
                    transducer_joint_combined_backward
diff --git a/apex/contrib/csrc/xentropy/xentropy_kernel.cu b/apex/contrib/csrc/xentropy/xentropy_kernel.cu
index f2711f6e1..4c9f1c4ed 100644
--- a/apex/contrib/csrc/xentropy/xentropy_kernel.cu
+++ b/apex/contrib/csrc/xentropy/xentropy_kernel.cu
@@ -72,6 +72,7 @@
  */
 #include
 #include
+#include <ATen/cuda/CUDAContext.h>
 
 #include
 #include
@@ -82,10 +83,8 @@
 #define ALIGN_BYTES 16
 
 #ifdef USE_ROCM
-#define WARP_SIZE 64
 #define SYNCWARP(mask)
 #else
-#define WARP_SIZE 32
 #define SYNCWARP(mask) __syncwarp(mask)
 #endif
 
@@ -130,7 +129,7 @@ inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
   uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads));
   while (block_size < (max_block_size/2)) block_size *= 2;
   // Launch at least a single warp - the kernel assumes that.
-  block_size = std::max(block_size, static_cast(WARP_SIZE));
+  block_size = std::max(block_size, static_cast(at::cuda::warp_size()));
   return dim3(block_size);
 }
 
@@ -199,13 +198,13 @@ blockReduce(AccumT* smem, AccumT val,
   AccumT warpVal = defaultVal;
 
   // First warp will perform per-warp reductions for the remaining warps
-  uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
-  if (threadIdx.x < WARP_SIZE) {
-    int lane = threadIdx.x % WARP_SIZE;
-    if (lane < blockDim.x / WARP_SIZE) {
+  uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1;
+  if (threadIdx.x < C10_WARP_SIZE) {
+    int lane = threadIdx.x % C10_WARP_SIZE;
+    if (lane < blockDim.x / C10_WARP_SIZE) {
 #pragma unroll
-      for (int i = 0; i < WARP_SIZE; ++i) {
-        warpVal = r(warpVal, smem[lane * WARP_SIZE + i]);
+      for (int i = 0; i < C10_WARP_SIZE; ++i) {
+        warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]);
       }
       SYNCWARP(mask);
       smem[lane] = warpVal;
@@ -218,7 +217,7 @@ blockReduce(AccumT* smem, AccumT val,
   AccumT blockVal = defaultVal;
 
   if (threadIdx.x == 0) {
-    for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
+    for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
@@ -253,14 +252,14 @@ blockReduce(AccumT* smem,
   AccumT warpVal2 = defaultVal2;
 
   // First warp will perform per-warp reductions for the remaining warps
-  uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1;
-  if (threadIdx.x < WARP_SIZE) {
-    int lane = threadIdx.x % WARP_SIZE;
-    if (lane < blockDim.x / WARP_SIZE) {
+  uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1;
+  if (threadIdx.x < C10_WARP_SIZE) {
+    int lane = threadIdx.x % C10_WARP_SIZE;
+    if (lane < blockDim.x / C10_WARP_SIZE) {
 #pragma unroll
-      for (int i = 0; i < WARP_SIZE; ++i) {
-        warpVal1 = r1(warpVal1, smem[lane * WARP_SIZE + i]);
-        warpVal2 = r2(warpVal2, smem[lane * WARP_SIZE + i + blockDim.x]);
+      for (int i = 0; i < C10_WARP_SIZE; ++i) {
+        warpVal1 = r1(warpVal1, smem[lane * C10_WARP_SIZE + i]);
+        warpVal2 = r2(warpVal2, smem[lane * C10_WARP_SIZE + i + blockDim.x]);
       }
       SYNCWARP(mask);
       smem[lane] = warpVal1;
@@ -275,7 +274,7 @@ blockReduce(AccumT* smem,
   AccumT blockVal2 = defaultVal2;
 
   if (threadIdx.x == 0) {
-    for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) {
+    for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) {
      blockVal1 = r1(blockVal1, smem[i]);
      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
    }
diff --git a/csrc/megatron/fused_rotary_positional_embedding.h b/csrc/megatron/fused_rotary_positional_embedding.h
index d2881b4a7..1f031c338 100644
--- a/csrc/megatron/fused_rotary_positional_embedding.h
+++ b/csrc/megatron/fused_rotary_positional_embedding.h
@@ -335,7 +335,7 @@ void dispatch_fused_rope_forward(const int s, const int b, const int h,
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_forward<<>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
@@ -356,7 +356,7 @@ void dispatch_fused_rope_backward(const int s, const int b, const int h,
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_backward<<>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
@@ -375,7 +375,7 @@ void dispatch_fused_rope_cached_forward(
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_cached_forward<<>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
@@ -394,7 +394,7 @@ void dispatch_fused_rope_cached_backward(
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_cached_backward<<>>(
      h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
@@ -415,7 +415,7 @@ void dispatch_fused_rope_thd_forward(const int max_s, const int b, const int h,
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(max_s, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_thd_forward<<>>(
      h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
@@ -434,7 +434,7 @@ void dispatch_fused_rope_thd_backward(
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(max_s, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_thd_backward<<>>(
      h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
@@ -454,7 +454,7 @@ void dispatch_fused_rope_2d_forward(
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(ih, iw, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_2d_forward<<>>(
      ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h, stride_d,
@@ -476,7 +476,7 @@ void dispatch_fused_rope_2d_backward(
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(ih, iw, b);
-  dim3 threads(C10_WARP_SIZE, warps_per_block);
+  dim3 threads(at::cuda::warp_size(), warps_per_block);
 
   fused_rope_2d_backward<<>>(
      ih, iw, h, d, stride_b, stride_ih, stride_iw, stride_h, stride_d,
diff --git a/csrc/megatron/generic_scaled_masked_softmax.h b/csrc/megatron/generic_scaled_masked_softmax.h
index 4ff50feb8..79fbc561d 100644
--- a/csrc/megatron/generic_scaled_masked_softmax.h
+++ b/csrc/megatron/generic_scaled_masked_softmax.h
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include <ATen/cuda/CUDAContext.h>
 
 namespace {
 
@@ -172,7 +173,7 @@ void dispatch_scaled_masked_softmax_backward_new(
    int batch_count = batches * attn_heads * query_seq_len;
    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
-    int num_warps = (key_seq_len - 1) / C10_WARP_SIZE + 1;
+    int num_warps = (key_seq_len - 1) / at::cuda::warp_size() + 1;
 
    dim3 blocks(batch_count, 1, 1);
    dim3 threads(threads_per_block, 1, 1);
@@ -374,7 +375,7 @@ void dispatch_scaled_masked_softmax_forward_new(
    constexpr int threads_per_block = 128;
 
    // calculate the needed shared memory
-    int num_warps = (key_seq_len - 1) / C10_WARP_SIZE + 1;
+    int num_warps = (key_seq_len - 1) / at::cuda::warp_size() + 1;
 
    dim3 blocks(batch_count, 1, 1);
    dim3 threads(threads_per_block, 1, 1);
diff --git a/csrc/megatron/scaled_masked_softmax.h b/csrc/megatron/scaled_masked_softmax.h
index f6e47d0b0..2674e1f54 100644
--- a/csrc/megatron/scaled_masked_softmax.h
+++ b/csrc/megatron/scaled_masked_softmax.h
@@ -663,7 +663,7 @@ void dispatch_scaled_masked_softmax_backward(
    int batch_count = batches * attn_heads * query_seq_len;
 
    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
-    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
 
    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
diff --git a/csrc/welford.cu b/csrc/welford.cu
index dd49b81f6..fabee1999 100644
--- a/csrc/welford.cu
+++ b/csrc/welford.cu
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include <ATen/cuda/CUDAContext.h>
 
 #include
 #include
@@ -44,17 +45,11 @@ __host__ __forceinline__ int h_last_pow2(unsigned int n) {
   return n - (n >> 1);
 }
 
-#ifdef USE_ROCM
-#define WARP_SIZE 64
-#else
-#define WARP_SIZE 32
-#endif
-
 template
 __device__ __forceinline__ T warp_reduce_sum(T val)
 {
   #pragma unroll
-  for(int i = WARP_SIZE/2; i > 0; i >>= 1)
+  for(int i = C10_WARP_SIZE/2; i > 0; i >>= 1)
     val = val + SHFL_DOWN(0xffffffff, val, i);
   return val;
 }
@@ -64,17 +59,17 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
 {
   int tid = threadIdx.y*blockDim.x + threadIdx.x;
   int blockSize = blockDim.x * blockDim.y;
-  int lane = tid % WARP_SIZE;
-  int wid = tid / WARP_SIZE;
+  int lane = tid % C10_WARP_SIZE;
+  int wid = tid / C10_WARP_SIZE;
 
-  if (blockSize > WARP_SIZE) {
+  if (blockSize > C10_WARP_SIZE) {
     val = warp_reduce_sum(val);
     if (lane == 0)
       x[wid] = val;
 
     __syncthreads();
 
-    val = (tid < blockSize / WARP_SIZE? x[lane] : T(0));
+    val = (tid < blockSize / C10_WARP_SIZE? x[lane] : T(0));
   }
 
   if(wid==0) val = warp_reduce_sum(val);
@@ -84,7 +79,6 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
 
 #define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
 #define ELEMENTS_PER_THREAD 16
-#define OPTIMAL_TILE_W WARP_SIZE
 #define MAX_H_BLOCK 128
 #define MAX_BLOCK_SIZE 512
 
@@ -98,7 +92,7 @@ __host__ void flexible_launch_configs(
     dim3 &block,
     dim3 &grid,
     const bool coop_flag = false) {
-  int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
+  int block_x = std::min(h_last_pow2(stride), at::cuda::warp_size());
   int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)),
                          MAX_BLOCK_SIZE / block_x);
   if (block_x * block_y != MAX_BLOCK_SIZE) {
@@ -138,7 +132,7 @@ template
 __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
 {
   #pragma unroll
-  for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
+  for(int i = C10_WARP_SIZE/2; i > 0; i >>= 1) {
     auto num_new = SHFL_DOWN(0xffffffff, num, i);
     auto mean_new = SHFL_DOWN(0xffffffff, mean, i);
     auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i);
@@ -156,10 +150,10 @@ __device__ void welford_reduce_mean_m2n(
     int block_size,
     int thread_id)
 {
-  int lane = thread_id % WARP_SIZE;
-  int wid = thread_id / WARP_SIZE;
+  int lane = thread_id % C10_WARP_SIZE;
+  int wid = thread_id / C10_WARP_SIZE;
 
-  if (block_size > WARP_SIZE) {
+  if (block_size > C10_WARP_SIZE) {
     warp_reduce_mean_m2n(mean, m2n, num);
     if (lane == 0) {
       x[wid*2] = mean;
@@ -169,9 +163,9 @@ __device__ void welford_reduce_mean_m2n(
     __syncthreads();
 
     if (wid == 0) {
-      mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0);
-      m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0);
-      num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0);
+      mean = (thread_id < block_size / C10_WARP_SIZE)? x[lane*2] : T(0);
+      m2n = (thread_id < block_size / C10_WARP_SIZE)? x[lane*2+1] : T(0);
+      num = (thread_id < block_size / C10_WARP_SIZE)? count[lane] : int(0);
     }
   }
@@ -295,8 +289,8 @@ __global__ void welford_kernel(
       }
     }
 
-  static __shared__ int s_mem[WARP_SIZE];
-  static __shared__ accscalar_t s_mem_ac[WARP_SIZE*2];
+  static __shared__ int s_mem[C10_WARP_SIZE];
+  static __shared__ accscalar_t s_mem_ac[C10_WARP_SIZE*2];
 
   welford_reduce_mean_m2n(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
@@ -353,7 +347,7 @@ __global__ void reduce_bn_kernel(
       const int bs,
       const int fs,
       const int ss) {
-  static __shared__ int s_mem[WARP_SIZE];
+  static __shared__ int s_mem[C10_WARP_SIZE];
   //int total_item_num = bs * ss;
 
   int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
@@ -952,7 +946,7 @@ std::vector welford_mean_var_CUDA(const at::Tensor input) {
   at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
   at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
 
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / WARP_SIZE));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / at::cuda::warp_size()));
   int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
@@ -988,7 +982,7 @@ at::Tensor batchnorm_forward_CUDA(
 
   auto space_size = get_tensor_spatial_size(input);
 
-  int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_x = max(at::cuda::warp_size(), min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
   int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
   const dim3 block(block_x, block_y);
   int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
@@ -1061,7 +1055,7 @@ std::vector reduce_bn_CUDA(
 
   auto space_size = get_tensor_spatial_size(input);
 
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ WARP_SIZE));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ at::cuda::warp_size()));
   int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
@@ -1128,7 +1122,7 @@ at::Tensor batchnorm_backward_CUDA(
 
   auto space_size = get_tensor_spatial_size(input);
 
-  int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_x = max(at::cuda::warp_size(), min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
   int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
   const dim3 block(block_x, block_y);
   int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
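
A minimal sketch (outside the diff itself) of the pattern these changes converge on: device code keeps the compile-time C10_WARP_SIZE constant, which remains valid for static shared-memory sizing and unrolled shuffle loops, while host-side launch heuristics switch to the runtime query at::cuda::warp_size(). The round_up_to_warp helper below is hypothetical and only illustrates the rounding idiom used in the launch configurations above.

#include <ATen/cuda/CUDAContext.h>  // declares at::cuda::warp_size()

// Hypothetical helper (illustration only): round a requested thread count up
// to a whole number of warps on the current device. at::cuda::warp_size()
// returns 32 on NVIDIA GPUs and 64 on most AMD GPUs, so the same host code
// works under both CUDA and ROCm builds without a compile-time #define.
static int round_up_to_warp(int threads) {
  const int ws = at::cuda::warp_size();  // runtime query, host code only
  return (threads + ws - 1) / ws * ws;   // e.g. 100 -> 128 when ws == 32
}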