2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Embedding.cu
@@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,

   int warp_size = at::cuda::warp_size();
   TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
-                        num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
+                        num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(),
                         "BlockReduceSum requires all warps be active");
   const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr<int64_t>();
   dim3 grid = unique_indices.numel();
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/MultinomialKernel.cu
@@ -86,7 +86,7 @@ void renormRows(Tensor& t) {
   TORCH_CHECK(props != nullptr);
   int numSM = props->multiProcessorCount;
   const int64_t maxThreads = std::min(
-      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);
+      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads());

   int warp_size = at::cuda::warp_size();
   dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
11 changes: 6 additions & 5 deletions aten/src/ATen/native/cuda/SoftMax.cu
@@ -177,15 +177,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
   uint64_t block_size = 1;
   uint64_t max_block_size = std::min(dim_size, static_cast<uint64_t>(max_threads));

-  // We need a block size that is a multiple of C10_WARP_SIZE in order
+  // We need a block size that is a multiple of at::cuda::warp_size() in order
   // to perform block size reductions using warp shuffle instructions.
   // Since max_threads is also a multiple of C10_WARPS_SIZE we do not
   // risk creating a block size larger than the limit.

-  if (max_block_size % C10_WARP_SIZE == 0) {
+  int warp_size = at::cuda::warp_size();
+  if (max_block_size % warp_size == 0) {
     block_size = max_block_size;
   } else {
-    block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE;
+    block_size = (max_block_size / warp_size + 1) * warp_size;
   }

   return dim3(block_size);
@@ -859,7 +860,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   } else {
     constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                   smem_reduction_sz) / sizeof(scalar_t);

@@ -895,7 +896,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
   } else {
     constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
     dim3 block = SoftMaxForward_getBlockSize(dim_size);
-    size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+    size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
     auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                   smem_reduction_sz) / sizeof(scalar_t);

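Note: SoftMaxForward_getBlockSize now rounds the block size up to a multiple of the warp size queried at runtime instead of the compile-time C10_WARP_SIZE. A minimal standalone sketch of that rounding step (the helper and its names are illustrative, not the ATen code; the warp size is passed in rather than queried):

#include <cstdint>

// Round max_block_size up to a multiple of warp_size (32 on NVIDIA, 32 or 64 on AMD).
uint64_t round_to_warp_multiple(uint64_t max_block_size, uint64_t warp_size) {
  if (max_block_size % warp_size == 0) {
    return max_block_size;
  }
  return (max_block_size / warp_size + 1) * warp_size;
}
// e.g. round_to_warp_multiple(250, 64) == 256, round_to_warp_multiple(128, 64) == 128.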
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/TensorModeKernel.cu
@@ -207,7 +207,7 @@ void handle_fused_mode(
   constexpr int num_threads = size / 2;
   int warp_size = at::cuda::warp_size();
   TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 &&
-                        num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, "");
+                        num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), "");
   const auto memsize =
       (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int));
   compute_mode<scalar_t, size>
4 changes: 4 additions & 0 deletions aten/src/ATen/native/cuda/TensorTopK.cu
@@ -439,8 +439,12 @@ __global__ void computeBlockwiseWithinKCounts(
     warp_counts[warp] = count;
   }
   __syncthreads();
+#ifdef USE_ROCM
+  CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE);
+#else
   static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
                 "Assuming only 1 warp is needed for final reduction");
+#endif
   if (warp != 0) {
     return;
   }
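Note: on ROCm the warp size (32 or 64) depends on the target architecture and C10_WARP_SIZE is no longer one fixed compile-time constant across compilation passes, so the compile-time bound above moves into a device-side runtime assert. A self-contained illustration of the bound itself; the digit count here is a placeholder, not taken from TensorTopK.cu:

// Hypothetical numbers for illustration only.
constexpr int kRadixDigitsExample = 256;  // stand-in for RADIX_DIGITS
constexpr int kNvidiaWarpSize = 32;       // fixed on CUDA, so static_assert works
static_assert(kRadixDigitsExample < kNvidiaWarpSize * kNvidiaWarpSize,
              "256 < 1024: a single warp can finish the final reduction");
// On ROCm the same inequality is checked at kernel run time with
// CUDA_KERNEL_ASSERT, because the warp size is a device property there.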
12 changes: 11 additions & 1 deletion aten/src/ATen/native/cuda/block_reduce.cuh
@@ -14,7 +14,17 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 // of which reduces C10_WARP_SIZE elements. So, at most
 // C10_WARP_SIZE**2 elements can be reduced at a time.
 // NOTE: This is >= the max block size on current hardware anyway (1024).
-constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
+// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions,
+// and kCUDABlockReduceMaxThreads is a host-side variable.
+#ifdef USE_ROCM
+static int kCUDABlockReduceMaxThreads() {
+  return at::cuda::warp_size() * at::cuda::warp_size();
+}
+#else
+constexpr int kCUDABlockReduceMaxThreads() {
+  return C10_WARP_SIZE * C10_WARP_SIZE;
+}
+#endif

 // Sums `val` across all threads in a warp.
 //
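Note: because the ROCm limit now depends on the runtime warp size, kCUDABlockReduceMaxThreads becomes a call rather than a constant, which is why the call sites above (Embedding.cu, MultinomialKernel.cu, TensorModeKernel.cu) gain parentheses. A minimal sketch of the host-side pattern, with stand-ins for at::cuda::warp_size() and the helper (names are illustrative):

#include <algorithm>

static int warp_size() { return 64; }  // stand-in for at::cuda::warp_size(): 32 or 64

static int kCUDABlockReduceMaxThreadsSketch() {
  // A block reduce uses one warp per warp_size values plus one final warp,
  // so at most warp_size^2 values can be reduced by a single block.
  return warp_size() * warp_size();
}

// Typical caller: clamp the desired launch size to the block-reduce limit.
int pick_block_threads(int max_threads_per_block) {
  return std::min(max_threads_per_block, kCUDABlockReduceMaxThreadsSketch());
}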
16 changes: 15 additions & 1 deletion c10/macros/Macros.h
@@ -330,7 +330,21 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 #endif

 #if defined(USE_ROCM)
-#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
+// C10_WARP_SIZE is only allowed for device code.
+// Host code _must_ use at::cuda::warp_size()
+// HIP header used to define warpSize as a constexpr that was either 32 or 64
+// depending on the target device, and then always set it to 64 for host code.
+// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
+// set it to something unreasonable to trigger obvious host code errors.
+#if defined(__HIP_DEVICE_COMPILE__)
+#if defined(__GFX9__)
+static constexpr int C10_WARP_SIZE = 64;
+#else // __GFX9__
+static constexpr int C10_WARP_SIZE = 32;
+#endif // __GFX9__
+#else
+static constexpr int C10_WARP_SIZE = 1;
+#endif // __HIP_DEVICE_COMPILE__
 #else
 #define C10_WARP_SIZE 32
 #endif
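Note: the net effect of the Macros.h change is that device code may keep using C10_WARP_SIZE (now 32 or 64 depending on the gfx target), while ROCm host code sees a deliberately useless value of 1 and must query at::cuda::warp_size(). A hedged sketch of the intended split; the kernel and launcher below are illustrative, not from the PR:

#include <ATen/cuda/CUDAContext.h>

__global__ void fill_lane_id(int* out) {
  // Device pass: C10_WARP_SIZE is a real per-architecture constant here.
  out[threadIdx.x] = threadIdx.x % C10_WARP_SIZE;
}

void launch_fill_lane_id(int* out) {
  // Host pass: on ROCm C10_WARP_SIZE is defined as 1, so query the runtime instead.
  int warp = at::cuda::warp_size();
  fill_lane_id<<<1, warp, 0, at::cuda::getCurrentCUDAStream()>>>(out);
}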
2 changes: 1 addition & 1 deletion third_party/composable_kernel
6 changes: 3 additions & 3 deletions torch/csrc/distributed/c10d/CUDASymmetricMemory.cu
@@ -436,7 +436,7 @@ static __global__ void barrier_kernel(
 void CUDASymmetricMemory::barrier(int channel) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  barrier_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       channel,
       rank_,
@@ -458,7 +458,7 @@ static __global__ void put_signal_kernel(
 void CUDASymmetricMemory::put_signal(int dst_rank, int channel) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  put_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       dst_rank,
       channel,
@@ -482,7 +482,7 @@ static __global__ void wait_signal_kernel(
 void CUDASymmetricMemory::wait_signal(int src_rank, int channel) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  wait_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       src_rank,
       channel,
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
@@ -64,7 +64,7 @@ void init_elementwise_launch_config(
     num_blocks = 1;
     num_threads = at::round_up(
         at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(C10_WARP_SIZE));
+        static_cast<size_t>(at::cuda::warp_size()));
   } else {
     num_blocks = std::min(
         at::ceil_div(
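Note: the single-block path above sizes the block by dividing the work per split by the elements handled per thread and rounding up to a warp multiple, now using the runtime warp size. A worked sketch with made-up sizes; ceil_div and round_up mirror at::ceil_div and at::round_up:

#include <cstddef>
#include <cstdio>

static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }
static size_t round_up(size_t a, size_t m) { return ceil_div(a, m) * m; }

int main() {
  size_t numel_per_split = 1000;  // hypothetical
  size_t numel_per_thread = 4;    // hypothetical
  size_t warp = 64;               // what at::cuda::warp_size() returns on gfx9 AMD GPUs
  // 1000 / 4 -> 250 threads, rounded up to the next warp multiple -> 256.
  std::printf("%zu\n", round_up(ceil_div(numel_per_split, numel_per_thread), warp));
  return 0;
}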
6 changes: 6 additions & 0 deletions torch/testing/_internal/common_cuda.py
@@ -48,6 +48,9 @@ def CDNA2OrLater():
 def evaluate_platform_supports_flash_attention():
     if TEST_WITH_ROCM:
         arch_list = ["gfx90a", "gfx942", "gfx1100"]
+        version = _get_torch_rocm_version()
+        if version >= (6, 5):
+            arch_list += ["gfx950"]
         return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return not IS_WINDOWS and SM80OrLater
@@ -56,6 +59,9 @@ def evaluate_platform_supports_flash_attention():
 def evaluate_platform_supports_efficient_attention():
     if TEST_WITH_ROCM:
         arch_list = ["gfx90a", "gfx942", "gfx1100"]
+        version = _get_torch_rocm_version()
+        if version >= (6, 5):
+            arch_list += ["gfx950"]
         return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return True