Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 5 additions & 18 deletions apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#endif
#include <stdint.h>
#include <algorithm>
#include <c10/macros/Macros.h>

#ifdef USE_ROCM
using bitmask_t = uint64_t;
Expand Down Expand Up @@ -62,20 +63,6 @@ DEVICE_FUNCTION void syncwarp() {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Returns the number of threads in a warp as a compile-time constant, so the
// result can be used in constant expressions (for example, as part of the
// extent of a statically sized __shared__ array).
//
//   * ROCm builds (USE_ROCM defined): 64 when __GFX9__ is defined, else 32.
//   * All other builds (CUDA): 32.
//
// The built-in `warpSize` variable is deliberately NOT returned on the CUDA
// path: it is not a compile-time constant, so returning it would make this
// constexpr function unusable wherever a constant expression is required.
DEVICE_FUNCTION constexpr int get_warp_size() {
#ifdef USE_ROCM
#if defined(__GFX9__)
    return 64;
#else
    return 32;
#endif
#else
    // Literal instead of `warpSize` so the call stays a constant expression.
    return 32;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
DEVICE_FUNCTION T shfl_sync(T var, int src_lane) {
#ifdef USE_ROCM
Expand Down Expand Up @@ -1075,7 +1062,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;

// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG];

// Compute the NHW coordinate of the thread in the CTA.
const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL;
Expand Down Expand Up @@ -1802,7 +1789,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;

// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG];

// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
Expand Down Expand Up @@ -2190,7 +2177,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;

// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG];

// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
Expand Down Expand Up @@ -2602,7 +2589,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG;

// Shared memory to do CTA-wide parallel sums.
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/get_warp_size())*ELEMENTS_PER_LDG];
__shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG];

// The adapter for the storage.
typedef PackedStorage<Storage, ELEMENTS_PER_LDG> PackedStorage_;
Expand Down