Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions csrc/megatron/scaled_masked_softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/CUDAContext.h>
#ifdef USE_ROCM
#if defined(__GFX9__)
#define WARP_SIZE_VALUE 64
#else
#define WARP_SIZE_VALUE 32
#endif
#else
#define WARP_SIZE_VALUE at::cuda::warp_size()
#endif

namespace {

Expand Down Expand Up @@ -447,7 +438,8 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;

int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE;
int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

constexpr int threads_per_block = 128;
Expand Down Expand Up @@ -476,7 +468,7 @@ void dispatch_scaled_softmax_forward(
int batch_count = batches * attn_heads * query_seq_len;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE;
int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
Expand Down Expand Up @@ -578,7 +570,7 @@ void dispatch_scaled_masked_softmax_forward(
int batch_count = batches * attn_heads * query_seq_len;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE;
int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
Expand Down
13 changes: 2 additions & 11 deletions csrc/megatron/scaled_upper_triang_masked_softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@
#include <stdint.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/CUDAContext.h>
#ifdef USE_ROCM
#if defined(__GFX9__)
#define WARP_SIZE_VALUE 64
#else
#define WARP_SIZE_VALUE 32
#endif
#else
#define WARP_SIZE_VALUE at::cuda::warp_size()
#endif

namespace {

Expand Down Expand Up @@ -360,7 +351,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int batch_count = attn_batches * seq_len;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE;
int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
Expand Down Expand Up @@ -463,7 +454,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int batch_count = attn_batches * seq_len;

// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < WARP_SIZE_VALUE) ? next_power_of_two : WARP_SIZE_VALUE;
int warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
Expand Down