From 9d7ba994cdfff554c0dd15d2ae6da493c57f0eca Mon Sep 17 00:00:00 2001
From: Ioannis Assiouras <38722728+iassiour@users.noreply.github.com>
Date: Sun, 6 Jul 2025 01:47:30 +0100
Subject: [PATCH] Do not use warpSize as a constexpr in nhwc_batch_norm_kernel.h

In ROCm 7.0, the warpSize variable is no longer constexpr.
This commit replaces the variable use with the correct values
based on the architecture we're running on.

On non-ROCm (CUDA) builds the helper returns the literal 32 rather than
the built-in warpSize, which nvcc does not accept in a constant
expression such as a __shared__ array bound.
---
 .../csrc/groupbn/nhwc_batch_norm_kernel.h     | 22 +++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
index f1fdd5241..0fc0faf7d 100644
--- a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
+++ b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
@@ -62,6 +62,20 @@ DEVICE_FUNCTION void syncwarp() {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+DEVICE_FUNCTION constexpr int get_warp_size() {  // Compile-time warp width for array bounds.
+#ifdef USE_ROCM
+  #if defined(__GFX9__)
+    return 64;  // Targets that define __GFX9__ use a width of 64.
+  #else
+    return 32;  // Every other ROCm target uses a width of 32.
+  #endif
+#else
+  return 32;  // CUDA warps are 32 threads; built-in warpSize is not a constant expression.
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename T>
 DEVICE_FUNCTION T shfl_sync(T var, int src_lane) {
 #ifdef USE_ROCM
@@ -1061,7 +1075,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
     const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
 
     // Shared memory to do CTA-wide parallel sums.
-    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG];
+    __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
 
     // Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; @@ -1788,7 +1802,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_; @@ -2176,7 +2190,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_; @@ -2588,7 +2602,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG]; + __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG]; // The adapter for the storage. typedef PackedStorage PackedStorage_;