diff --git a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
index f1fdd5241..0fc0faf7d 100644
--- a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
+++ b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
@@ -62,6 +62,20 @@ DEVICE_FUNCTION void syncwarp() {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+DEVICE_FUNCTION constexpr int get_warp_size() {  // Compile-time warp width; sizes the smem[] arrays below.
+#ifdef USE_ROCM
+  #if defined(__GFX9__)
+    return 64;  // Value selected when the compiler defines __GFX9__.
+  #else
+    return 32;  // Every other ROCm target.
+  #endif
+#else
+    return 32;  // Literal, not the warpSize built-in: warpSize is not a constant expression under CUDA.
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template
 DEVICE_FUNCTION T shfl_sync(T var, int src_lane) {
 #ifdef USE_ROCM
@@ -1061,7 +1075,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
     const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
 
     // Shared memory to do CTA-wide parallel sums.
-    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
+    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
 
     // Compute the NHW coordinate of the thread in the CTA.
     const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
@@ -1788,7 +1802,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
     const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
 
     // Shared memory to do CTA-wide parallel sums.
-    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
+    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
 
     // The adapter for the storage.
     typedef PackedStorage PackedStorage_;
@@ -2176,7 +2190,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
     const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
 
     // Shared memory to do CTA-wide parallel sums.
-    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
+    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
 
     // The adapter for the storage.
     typedef PackedStorage PackedStorage_;
@@ -2588,7 +2602,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
     const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
 
     // Shared memory to do CTA-wide parallel sums.
-    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
+    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
 
     // The adapter for the storage.
     typedef PackedStorage PackedStorage_;