From d761ccf6006880ee3c3a0539bf27b52dfae54cc4 Mon Sep 17 00:00:00 2001 From: Sriram Kumar Date: Thu, 10 Jul 2025 00:51:50 +0300 Subject: [PATCH] Fixing the C10_warpsize issue. replacing the macros with at::cuda::warp_size() (#237) --- csrc/megatron/scaled_masked_softmax.h | 16 ++++------------ .../scaled_upper_triang_masked_softmax.h | 13 ++----------- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/csrc/megatron/scaled_masked_softmax.h b/csrc/megatron/scaled_masked_softmax.h index f275ba228..f6e47d0b0 100644 --- a/csrc/megatron/scaled_masked_softmax.h +++ b/csrc/megatron/scaled_masked_softmax.h @@ -24,15 +24,6 @@ #include #include #include -#ifdef USE_ROCM - #if defined(__GFX9__) - #define WARP_SIZE_VALUE 64 - #else - #define WARP_SIZE_VALUE 32 - #endif -#else -#define WARP_SIZE_VALUE at::cuda::warp_size() -#endif namespace { @@ -447,7 +438,8 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; - int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; constexpr int threads_per_block = 128; @@ -476,7 +468,7 @@ void dispatch_scaled_softmax_forward( int batch_count = batches * attn_heads * query_seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -578,7 +570,7 @@ void dispatch_scaled_masked_softmax_forward( int batch_count = batches * attn_heads * query_seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.h b/csrc/megatron/scaled_upper_triang_masked_softmax.h index e33684dd7..562350af2 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.h +++ b/csrc/megatron/scaled_upper_triang_masked_softmax.h @@ -23,15 +23,6 @@ #include #include #include -#ifdef USE_ROCM - #if defined(__GFX9__) - #define WARP_SIZE_VALUE 64 - #else - #define WARP_SIZE_VALUE 32 - #endif -#else -#define WARP_SIZE_VALUE at::cuda::warp_size() -#endif namespace { @@ -360,7 +351,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( int batch_count = attn_batches * seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -463,7 +454,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( int batch_count = attn_batches * seq_len; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE; + int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size(); // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;