From d9d020aeb2427914e76fabda27270f863a0a1412 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Fri, 20 Jun 2025 21:19:59 +0000 Subject: [PATCH 1/2] remove warpSize usage on host side --- aten/src/ATen/native/cuda/SoftMax.cu | 11 ++++++----- aten/src/ATen/native/cuda/TensorTopK.cu | 4 ++++ aten/src/ATen/native/cuda/block_reduce.cuh | 6 ++++++ c10/macros/Macros.h | 16 +++++++++++++++- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index 4aca753a510b8..c908d1d525c52 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -177,15 +177,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) { uint64_t block_size = 1; uint64_t max_block_size = std::min(dim_size, static_cast(max_threads)); - // We need a block size that is a multiple of C10_WARP_SIZE in order + // We need a block size that is a multiple of at::cuda::warp_size() in order // to perform block size reductions using warp shuffle instructions. // Since max_threads is also a multiple of C10_WARPS_SIZE we do not // risk creating a block size larger than the limit. - if (max_block_size % C10_WARP_SIZE == 0) { + int warp_size = at::cuda::warp_size(); + if (max_block_size % warp_size == 0) { block_size = max_block_size; } else { - block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE; + block_size = (max_block_size / warp_size + 1) * warp_size; } return dim3(block_size); @@ -859,7 +860,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); @@ -895,7 +896,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t } else { constexpr int ILP = sizeof(float4) / sizeof(scalar_t); dim3 block = SoftMaxForward_getBlockSize(dim_size); - size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t); + size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t); auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock - smem_reduction_sz) / sizeof(scalar_t); diff --git a/aten/src/ATen/native/cuda/TensorTopK.cu b/aten/src/ATen/native/cuda/TensorTopK.cu index d06efa6635131..3b9b6b070b5e6 100644 --- a/aten/src/ATen/native/cuda/TensorTopK.cu +++ b/aten/src/ATen/native/cuda/TensorTopK.cu @@ -456,8 +456,12 @@ __global__ void computeBlockwiseWithinKCounts( warp_counts[warp] = count; } __syncthreads(); +#ifdef USE_ROCM + CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE); +#else static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE, "Assuming only 1 warp is needed for final reduction"); +#endif if (warp != 0) { return; } diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index df757a11761bb..112c6ab952574 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -14,7 +14,13 @@ constexpr int kCUDABlockReduceNumThreads = 512; // of which reduces C10_WARP_SIZE elements. So, at most // C10_WARP_SIZE**2 elements can be reduced at a time. // NOTE: This is >= the max block size on current hardware anyway (1024). +// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions, +// and kCUDABlockReduceMaxThreads is a host-side variable. +#ifdef USE_ROCM +static const int kCUDABlockReduceMaxThreads = at::cuda::warp_size() * at::cuda::warp_size(); +#else constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; +#endif // Sums `val` across all threads in a warp. // diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index a66933823d80f..e2b3fb0e588d8 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -311,7 +311,21 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #endif #if defined(USE_ROCM) -#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h) +// C10_WARP_SIZE is only allowed for device code. +// Host code _must_ use at::cuda::warp_size() +// HIP header used to define warpSize as a constexpr that was either 32 or 64 +// depending on the target device, and then always set it to 64 for host code. +// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we +// set it to something unreasonable to trigger obvious host code errors. +#if defined(__HIP_DEVICE_COMPILE__) +#if defined(__GFX9__) +static constexpr int C10_WARP_SIZE = 64; +#else // __GFX9__ +static constexpr int C10_WARP_SIZE = 32; +#endif // __GFX9__ +#else +static constexpr int C10_WARP_SIZE = 1; +#endif // __HIP_DEVICE_COMPILE__ #else #define C10_WARP_SIZE 32 #endif From d6b46125654874f27e05d3272f7fe0478380dfb5 Mon Sep 17 00:00:00 2001 From: Ethan Wee Date: Wed, 25 Jun 2025 15:10:51 -0700 Subject: [PATCH 2/2] [rocm7.0_internal_testing] Prevent static initialization of at::cuda::warp_size() (#2293) Fixes SWDEV-540240, SWDEV-540309, SWDEV-539989 ``` ... ``` https://github.com/ROCm/pytorch/commit/80cca7006d94df97ee932fd5903ed20c08c2eb34 created a static global variable that used `at::cuda::warp_size()` to initialize its value, which needs GPUs to be visible to query device properties. However, GPUs are not present on CPU-only build systems. Convert static variable into a static function, thus preventing static initialization. http://rocm-ci.amd.com/job/pyt_whl_docker_mainline/1461/artifact/build_artifacts.txt/*view*/ Ran microbenchmark to confirm basic functionality: ``` root@ubb4-rack-22:/var/lib/jenkins/pytorch-micro-benchmarking# python3 micro_benchmarking_pytorch.py --network resnet50 INFO: running forward and backward for warmup. INFO: running the benchmark.. OK: finished running benchmark.. --------------------SUMMARY-------------------------- Microbenchmark for network : resnet50 Num devices: 1 Dtype: FP32 Mini batch size [img] : 64 Time per mini-batch : 0.10158218145370483 Throughput [img/sec] : 630.0317544289736= ``` --- aten/src/ATen/native/cuda/Embedding.cu | 2 +- aten/src/ATen/native/cuda/MultinomialKernel.cu | 2 +- aten/src/ATen/native/cuda/TensorModeKernel.cu | 2 +- aten/src/ATen/native/cuda/block_reduce.cuh | 8 ++++++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index b8fb51304e4b0..5a02d199ed6b0 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 && - num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads, + num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(), "BlockReduceSum requires all warps be active"); const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr(); dim3 grid = unique_indices.numel(); diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 3e67f5ad5bfbe..72374095baac2 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -86,7 +86,7 @@ void renormRows(Tensor& t) { TORCH_CHECK(props != nullptr); int numSM = props->multiProcessorCount; const int64_t maxThreads = std::min( - props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads); + props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads()); int warp_size = at::cuda::warp_size(); dim3 grid(rows < numSM * 4 ? rows : numSM * 4); diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cu b/aten/src/ATen/native/cuda/TensorModeKernel.cu index b848ed5748e5c..be158584cedb8 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cu +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cu @@ -209,7 +209,7 @@ void handle_fused_mode( constexpr int num_threads = size / 2; int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 && - num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, ""); + num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), ""); const auto memsize = (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int)); compute_mode diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index 112c6ab952574..fc44b0f4a9da0 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -17,9 +17,13 @@ constexpr int kCUDABlockReduceNumThreads = 512; // ROCm NOTE: C10_WARP_SIZE should only be used inside device functions, // and kCUDABlockReduceMaxThreads is a host-side variable. #ifdef USE_ROCM -static const int kCUDABlockReduceMaxThreads = at::cuda::warp_size() * at::cuda::warp_size(); +static int kCUDABlockReduceMaxThreads() { + return at::cuda::warp_size() * at::cuda::warp_size(); +} #else -constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; +constexpr int kCUDABlockReduceMaxThreads() { + return C10_WARP_SIZE * C10_WARP_SIZE; +} #endif // Sums `val` across all threads in a warp.