
Commit 538a57d

jeffdaily authored and jithunnair-amd committed
remove warpSize usage on host side
1 parent 5c54a45 commit 538a57d

7 files changed, +36 -11 lines changed

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 6 additions & 5 deletions
@@ -177,15 +177,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
   uint64_t block_size = 1;
   uint64_t max_block_size = std::min(dim_size, static_cast<uint64_t>(max_threads));
 
-  // We need a block size that is a multiple of C10_WARP_SIZE in order
+  // We need a block size that is a multiple of at::cuda::warp_size() in order
   // to perform block size reductions using warp shuffle instructions.
   // Since max_threads is also a multiple of C10_WARPS_SIZE we do not
   // risk creating a block size larger than the limit.
 
-  if (max_block_size % C10_WARP_SIZE == 0) {
+  int warp_size = at::cuda::warp_size();
+  if (max_block_size % warp_size == 0) {
     block_size = max_block_size;
   } else {
-    block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE;
+    block_size = (max_block_size / warp_size + 1) * warp_size;
   }
 
   return dim3(block_size);
@@ -859,7 +860,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   } else {
     constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
         smem_reduction_sz) / sizeof(scalar_t);
 
@@ -895,7 +896,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   } else {
     constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
         smem_reduction_sz) / sizeof(scalar_t);
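The replacement pattern throughout this commit is the same: host code asks the runtime for the warp width instead of relying on a compile-time macro, because a single ROCm build can run on GPUs with either 32- or 64-wide warps. A minimal sketch of what such a host-side query boils down to (at::cuda::warp_size() in ATen wraps a device-properties lookup; the standalone helper below is illustrative, not part of this commit):

#include <ATen/cuda/CUDAContext.h>

// Illustrative host-side warp-size query. getCurrentDeviceProperties() is the
// same call already used in the hunk above; warpSize is 32 on CUDA GPUs and
// 32 or 64 on ROCm, depending on the active device.
static int host_warp_size_sketch() {
  return at::cuda::getCurrentDeviceProperties()->warpSize;
}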

aten/src/ATen/native/cuda/TensorTopK.cu

Lines changed: 4 additions & 0 deletions
@@ -439,8 +439,12 @@ __global__ void computeBlockwiseWithinKCounts(
     warp_counts[warp] = count;
   }
   __syncthreads();
+#ifdef USE_ROCM
+  CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE);
+#else
   static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
                 "Assuming only 1 warp is needed for final reduction");
+#endif
   if (warp != 0) {
     return;
   }
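The compile-time assert becomes a runtime kernel assert on ROCm because of the c10/macros/Macros.h change further down: the host pass of the HIP compiler now sees C10_WARP_SIZE as a deliberately unusable 1, so a static_assert parsed in both passes would fire spuriously, while CUDA_KERNEL_ASSERT evaluates the condition in device code where C10_WARP_SIZE is the real 32 or 64. A hedged sketch of the same pattern in isolation (the kernel and the RADIX_DIGITS_EXAMPLE constant are illustrative, assuming the usual c10/macros/Macros.h definitions):

#include <c10/macros/Macros.h>  // C10_WARP_SIZE, CUDA_KERNEL_ASSERT

constexpr int RADIX_DIGITS_EXAMPLE = 256;  // hypothetical stand-in

__global__ void reduction_precondition_check_sketch() {
#ifdef USE_ROCM
  // Checked at runtime on the device, where C10_WARP_SIZE is 32 or 64.
  CUDA_KERNEL_ASSERT(RADIX_DIGITS_EXAMPLE < C10_WARP_SIZE * C10_WARP_SIZE);
#else
  // On CUDA, C10_WARP_SIZE is a plain 32 in every compilation pass, so the
  // check can stay at compile time.
  static_assert(RADIX_DIGITS_EXAMPLE < C10_WARP_SIZE * C10_WARP_SIZE,
      "Assuming only 1 warp is needed for final reduction");
#endif
}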

aten/src/ATen/native/cuda/block_reduce.cuh

Lines changed: 6 additions & 0 deletions
@@ -12,7 +12,13 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 // of which reduces C10_WARP_SIZE elements. So, at most
 // C10_WARP_SIZE**2 elements can be reduced at a time.
 // NOTE: This is >= the max block size on current hardware anyway (1024).
+// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions,
+// and kCUDABlockReduceMaxThreads is a host-side variable.
+#ifdef USE_ROCM
+static const int kCUDABlockReduceMaxThreads = at::cuda::warp_size() * at::cuda::warp_size();
+#else
 constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
+#endif
 
 // Sums `val` across all threads in a warp.
 //
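Note the type change that comes with this: on ROCm the limit is a static const int initialized from the runtime query (4096 on a 64-wide device, 1024 on a 32-wide one), no longer a constexpr, which is consistent with the ROCm NOTE that it is meant for host-side use only. A hedged host-side usage sketch; the helper name is hypothetical and the limit is passed in rather than referenced by its real namespace:

#include <ATen/cuda/CUDAContext.h>
#include <algorithm>
#include <cstdint>

// Illustrative only: clamp a reduction size to the block-reduce limit
// (kCUDABlockReduceMaxThreads in real callers) and round it up to a whole
// number of warps so warp-shuffle reductions line up.
int64_t pick_block_reduce_block_size(int64_t n, int64_t max_threads_limit) {
  int64_t block = std::min(n, max_threads_limit);
  int64_t warp = at::cuda::warp_size();
  return (block + warp - 1) / warp * warp;
}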

c10/macros/Macros.h

Lines changed: 15 additions & 1 deletion
@@ -312,7 +312,21 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 #endif
 
 #if defined(USE_ROCM)
-#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
+// C10_WARP_SIZE is only allowed for device code.
+// Host code _must_ use at::cuda::warp_size()
+// HIP header used to define warpSize as a constexpr that was either 32 or 64
+// depending on the target device, and then always set it to 64 for host code.
+// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
+// set it to something unreasonable to trigger obvious host code errors.
+#if defined(__HIP_DEVICE_COMPILE__)
+#if defined(__GFX9__)
+static constexpr int C10_WARP_SIZE = 64;
+#else // __GFX9__
+static constexpr int C10_WARP_SIZE = 32;
+#endif // __GFX9__
+#else
+static constexpr int C10_WARP_SIZE = 1;
+#endif // __HIP_DEVICE_COMPILE__
 #else
 #define C10_WARP_SIZE 32
 #endif
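Taken together with the host-side changes above, the intended split is: device code may keep using C10_WARP_SIZE, because the device pass knows the target architecture (64 when __GFX9__ is defined, 32 otherwise), while host code must call at::cuda::warp_size(); the host-pass value of 1 exists only to make leftover host uses fail loudly. A small sketch of that split (the kernel and launcher are illustrative, not part of this commit):

#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>

// Device code: C10_WARP_SIZE is the true warp width of the compiled target.
__global__ void warp_strided_copy_sketch(const float* in, float* out, int n) {
  for (int i = threadIdx.x; i < n; i += C10_WARP_SIZE) {
    out[i] = in[i];
  }
}

// Host code: query the active device instead of using the macro.
void launch_warp_strided_copy_sketch(const float* in, float* out, int n) {
  warp_strided_copy_sketch<<<1, at::cuda::warp_size(), 0,
                             at::cuda::getCurrentCUDAStream()>>>(in, out, n);
}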

third_party/composable_kernel

Submodule composable_kernel updated 683 files

torch/csrc/distributed/c10d/CUDASymmetricMemory.cu

Lines changed: 3 additions & 3 deletions
@@ -488,7 +488,7 @@ static __global__ void barrier_kernel(
 void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  barrier_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       channel,
       rank_,
@@ -526,7 +526,7 @@ void CUDASymmetricMemory::put_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  put_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       dst_rank,
       channel,
@@ -570,7 +570,7 @@ void CUDASymmetricMemory::wait_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  wait_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       src_rank,
       channel,
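All three launches here size the block to exactly one warp, which is why the thread count has to come from the runtime query on ROCm. A hypothetical helper capturing the repeated pattern might look like the sketch below (not part of this commit; the call sites above keep their explicit launches):

#include <ATen/cuda/CUDAContext.h>

// Hypothetical convenience wrapper: launch `kernel` with a single warp of
// threads on the current stream, forwarding the kernel arguments as-is.
template <typename KernelT, typename... Args>
void launch_single_warp_sketch(KernelT kernel, Args... args) {
  kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(args...);
}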

torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ void init_elementwise_launch_config(
     num_blocks = 1;
     num_threads = at::round_up(
         at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(C10_WARP_SIZE));
+        static_cast<size_t>(at::cuda::warp_size()));
   } else {
     num_blocks = std::min(
         at::ceil_div(numel_per_split, max_num_threads * numel_per_thread),
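The effect of the change is that the thread count is rounded up to a multiple of whatever the active device's warp width is: for example, 640 elements at 4 per thread gives ceil_div = 160 threads, which stays 160 on a 32-wide device but becomes 192 on a 64-wide ROCm device. A standalone sketch of the same arithmetic (at::ceil_div and at::round_up are the helpers used above; these plain-integer versions are illustrative):

#include <cstddef>

// ceil_div(a, b): smallest integer >= a / b.
static size_t ceil_div_sketch(size_t a, size_t b) { return (a + b - 1) / b; }

// round_up(a, m): smallest multiple of m that is >= a.
static size_t round_up_sketch(size_t a, size_t m) {
  return ceil_div_sketch(a, m) * m;
}
// Example: round_up_sketch(ceil_div_sketch(640, 4), 64) == 192.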
