Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aten/src/ATen/native/cuda/layer_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ namespace at::native {
namespace {

// Thread-count constant used by this file (its uses are outside this hunk).
constexpr int kCUDANumThreads = 256;
// kWarpSize: file-local name for the warp size, taken from C10_WARP_SIZE.
#ifdef USE_ROCM
// C10_WARP_SIZE is not constexpr for host code in ROCm builds, so it cannot
// initialize a constexpr variable here. kWarpSize is therefore a macro that
// forwards to C10_WARP_SIZE. Note that, being a macro, it is not scoped by the
// enclosing anonymous namespace.
#define kWarpSize C10_WARP_SIZE
#else
// Non-ROCm builds: C10_WARP_SIZE is usable as a constant expression.
constexpr unsigned int kWarpSize = C10_WARP_SIZE;
#endif
constexpr int vec_size = 4; // We could make it dependent on dtype, but that would lead to different results between float and low-precision types.

// aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh)
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,11 @@ __global__ void coalesceValuesKernel(
// `if constexpr` when CUDA codes will be compiled under C++-17, see
// gh-56055 for blockers.
template<typename Dtype>
#ifdef USE_ROCM
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE_STATIC*4)
#else
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4)
#endif
__global__ void coalesceValuesKernel(
int64_t *segment_offsets, int64_t *value_indices,
bool *values, bool *newValues,
Expand Down
30 changes: 23 additions & 7 deletions c10/macros/Macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,16 +318,32 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
// depending on the target device, and then always set it to 64 for host code.
// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
// set it to something unreasonable to trigger obvious host code errors.
#if defined(__HIP_DEVICE_COMPILE__)

// Declaration only (no body here) of at::cuda::warp_size(), so that the
// host-side C10_WARP_SIZE_INTERNAL() definitions in this header can call it.
namespace at::cuda {
TORCH_CUDA_CPP_API int warp_size();
}
#ifdef __HIPCC__
// Host overload (HIP compiler): forwards to at::cuda::warp_size(). It is not
// constexpr, so the result is not usable as a compile-time constant.
static inline int __host__ C10_WARP_SIZE_INTERNAL() {
return at::cuda::warp_size();
}

static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL() {
#if defined(__GFX9__)
static constexpr int C10_WARP_SIZE = 64;
return 64;
#else // __GFX9__
static constexpr int C10_WARP_SIZE = 32;
return 32;
#endif // __GFX9__
#else
static constexpr int C10_WARP_SIZE = 1;
#endif // __HIP_DEVICE_COMPILE__
#else
}
#else // __HIPCC__
// Variant used when __HIPCC__ is not defined: forwards to
// at::cuda::warp_size(). Not constexpr, so not a compile-time constant.
inline int C10_WARP_SIZE_INTERNAL() {
return at::cuda::warp_size();
}
#endif // __HIPCC__

// C10_WARP_SIZE expands to a call of C10_WARP_SIZE_INTERNAL(); on the host
// paths above that is a function call, not a constant expression.
#define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL())
// Fixed literal for contexts that need a compile-time value (e.g. the
// C10_LAUNCH_BOUNDS_1 argument in SparseCUDAApplyUtils.cuh). 64 is the larger
// of the two values the __device__ overload returns (64 on __GFX9__, else 32).
// NOTE(review): confirm callers only rely on it as an upper bound.
#define C10_WARP_SIZE_STATIC 64

#else // defined(USE_ROCM)
#define C10_WARP_SIZE 32
#endif

Expand Down