From f2c120147200e966fda0c9ab6256bb4c56ddc677 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Mon, 14 Jul 2025 12:03:25 -0500
Subject: [PATCH] [release/2.6] Improve C10_WARP_SIZE compatibility (#2328)

If compiling with HIPCC (i.e., `__HIPCC__` is
[defined](https://rocm.docs.amd.com/projects/HIP/en/docs-develop/how-to/hip_porting_guide.html#compiler-defines-summary)):

* Define `C10_WARP_SIZE` as the non-constexpr `at::cuda::warp_size()` for the
  host-compilation pass (as opposed to the
  `static constexpr int C10_WARP_SIZE = 1;` set in
  https://github.com/ROCm/pytorch/commit/538a57d13af24f96cb356b8313d47fd575c06a82)
* Define `C10_WARP_SIZE` as constexpr `64` for `__GFX9__`, and constexpr `32`
  otherwise, for the device-compilation pass

If not compiling with HIPCC:

* Define `C10_WARP_SIZE` as the non-constexpr `at::cuda::warp_size()`

For host-compilation cases that need a constexpr warp size (e.g. launch
bounds), use `C10_WARP_SIZE_STATIC`, defined as `64` (better to err on 64 for
launch bounds).
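As an illustration of the intended usage (a hypothetical kernel, not code
from this patch): device code may keep treating `C10_WARP_SIZE` as constexpr,
host code must treat it as a runtime value, and launch bounds should use
`C10_WARP_SIZE_STATIC`. `my_fill_kernel`/`my_launch` are made-up names, and
the sketch assumes a file compiled as CUDA/HIP source (so `__HIPCC__` is
defined on the ROCm toolchain):

```cpp
#include <c10/macros/Macros.h>

#ifdef USE_ROCM
// Host pass: C10_WARP_SIZE is a runtime call, so launch bounds need the
// constexpr upper bound C10_WARP_SIZE_STATIC (64) instead.
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE_STATIC * 4)
#else
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * 4)
#endif
__global__ void my_fill_kernel(float* out, int n) {
  // Device pass: C10_WARP_SIZE stays constexpr (64 on __GFX9__, 32
  // otherwise), so it can still size static shared memory.
  constexpr int ws = C10_WARP_SIZE;
  __shared__ float scratch[ws];
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  scratch[threadIdx.x % ws] = 1.0f;
  __syncthreads();
  if (i < n) {
    out[i] = scratch[threadIdx.x % ws];
  }
}

void my_launch(float* out, int n) {
  // Host code may still read C10_WARP_SIZE, but only as a runtime value
  // (it expands to at::cuda::warp_size() on ROCm builds).
  const int threads = C10_WARP_SIZE * 4;
  my_fill_kernel<<<(n + threads - 1) / threads, threads>>>(out, n);
}
```

Since launch bounds only declare the maximum block size a kernel may be
launched with, erring on 64 in `C10_WARP_SIZE_STATIC` stays valid on warp-32
devices.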
Fixes SWDEV-542227

---------

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
---
 .../src/ATen/native/cuda/layer_norm_kernel.cu |  5 ++++
 .../sparse/cuda/SparseCUDAApplyUtils.cuh      |  4 +++
 c10/macros/Macros.h                           | 30 ++++++++++++++-----
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index ee573e2e566f6..e34a3aca01b18 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -33,7 +33,12 @@ namespace at::native {
 namespace {

 constexpr int kCUDANumThreads = 256;
+#ifdef USE_ROCM
+// C10_WARP_SIZE is not constexpr for host code.
+#define kWarpSize C10_WARP_SIZE
+#else
 constexpr unsigned int kWarpSize = C10_WARP_SIZE;
+#endif
 constexpr int vec_size = 4; //we could make it dependent on dtype, but that would lead to different results between float and low-p types

 // aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh
index c9412d74e9cda..693ca536a3198 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh
@@ -242,7 +242,11 @@ __global__ void coalesceValuesKernel(
 // `if constexpr` when CUDA codes will be compiled under C++-17, see
 // gh-56055 for blockers.
 template
+#ifdef USE_ROCM
+C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE_STATIC*4)
+#else
 C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4)
+#endif
 __global__ void coalesceValuesKernel(
   int64_t *segment_offsets, int64_t *value_indices,
   bool *values, bool *newValues,
diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h
index 847880c8a0458..be835ce108917 100644
--- a/c10/macros/Macros.h
+++ b/c10/macros/Macros.h
@@ -318,16 +318,32 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 // depending on the target device, and then always set it to 64 for host code.
 // Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
 // set it to something unreasonable to trigger obvious host code errors.
-#if defined(__HIP_DEVICE_COMPILE__)
+
+namespace at::cuda {
+TORCH_CUDA_CPP_API int warp_size();
+}
+#ifdef __HIPCC__
+static inline int __host__ C10_WARP_SIZE_INTERNAL() {
+  return at::cuda::warp_size();
+}
+
+static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL() {
 #if defined(__GFX9__)
-static constexpr int C10_WARP_SIZE = 64;
+  return 64;
 #else // __GFX9__
-static constexpr int C10_WARP_SIZE = 32;
+  return 32;
 #endif // __GFX9__
-#else
-static constexpr int C10_WARP_SIZE = 1;
-#endif // __HIP_DEVICE_COMPILE__
-#else
+}
+#else // __HIPCC__
+inline int C10_WARP_SIZE_INTERNAL() {
+  return at::cuda::warp_size();
+}
+#endif // __HIPCC__
+
+#define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL())
+#define C10_WARP_SIZE_STATIC 64
+
+#else // defined(USE_ROCM)
 #define C10_WARP_SIZE 32
 #endif
