diff --git a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h index 0fc0faf7d..0dd47a340 100644 --- a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h +++ b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h @@ -33,6 +33,7 @@ #endif #include #include +#include #ifdef USE_ROCM using bitmask_t = uint64_t; @@ -62,20 +63,6 @@ DEVICE_FUNCTION void syncwarp() { //////////////////////////////////////////////////////////////////////////////////////////////////// -DEVICE_FUNCTION constexpr int get_warp_size() { -#ifdef USE_ROCM - #if defined(__GFX9__) - return 64; - #else - return 32; - #endif -#else - return warpSize; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template DEVICE_FUNCTION T shfl_sync(T var, int src_lane) { #ifdef USE_ROCM @@ -1075,7 +1062,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // Compute the NHW coordinate of the thread in the CTA. const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; @@ -1802,7 +1789,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_; @@ -2190,7 +2177,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_; @@ -2602,7 +2589,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_;