[CUDA] Add SparseAttention kernel for sm=75 #20531

Merged · 5 commits · May 2, 2024

Changes from 2 commits
28 changes: 20 additions & 8 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
@@ -63,6 +63,15 @@ SparseAttention<T>::SparseAttention(const OpKernelInfo& info)

template <typename T>
Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto& device_prop = GetDeviceProp();
if constexpr (std::is_same<T, BFloat16>::value) {
if (device_prop.major < 8) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"bfloat16 requires Ampere and above GPUs with Compute Capability >= 8. Got major=",
device_prop.major);
}
}

const Tensor* query = context->Input<Tensor>(0);
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
@@ -74,8 +83,6 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* cos_cache = context->Input<Tensor>(8);
const Tensor* sin_cache = context->Input<Tensor>(9);

auto& device_prop = GetDeviceProp();

SparseAttentionParameters parameters;

// Parameters from node attribute
@@ -97,9 +104,9 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
block_mask,
seqlens_k_total,
total_seq_len));

// Some limitations of CUDA kernels
if (!sparse_attention_v1::is_supported_sparse_attention(device_prop)) {
// The v1 and v2 kernels have the same device coverage, so we only need to check one of them here.
if (!sparse_attention_v1::is_supported_device(device_prop)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SparseAttention only support CUDA device with compute capacity 8.*. Got ",
device_prop.major);
@@ -118,6 +125,9 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {

int past_seq_len = parameters.total_sequence_length - parameters.sequence_length;
bool is_prompt = (past_seq_len == 0);

// The v1 kernel supports only the prompt case, with right-side padding only.
// The v2 kernel supports both prompt and token generation, and both left-side and right-side padding.
bool use_v2_kernel = disable_v1_kernel_ || !is_prompt;

// Async Copy total_k_seq_len from GPU to CPU.
@@ -139,19 +149,21 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
}

if (!kernel_loaded_) {
int sm = device_prop.major * 10 + device_prop.minor;
if constexpr (std::is_same<T, MLFloat16>::value) {
// std::call_once is used in load_sparse_attention_fp16, so there is no need for a mutex here.
// Once a kernel is loaded, it stays in memory until the process exits; we do not unload it explicitly.
// TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_fp16();
sparse_attention_v2::load_sparse_attention_fp16(sm);
} else {
sparse_attention_v1::load_sparse_attention_fp16();
sparse_attention_v1::load_sparse_attention_fp16(sm);
}
} else {
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_bf16();
sparse_attention_v2::load_sparse_attention_bf16(sm);
} else {
sparse_attention_v1::load_sparse_attention_bf16();
sparse_attention_v1::load_sparse_attention_bf16(sm);
}
}
kernel_loaded_ = true;
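The comments above summarize the dispatch rule: the v1 kernel covers only the prompt (prefill) case with right-side padding, while the v2 kernel also covers token generation and left-side padding. A minimal Python sketch of that rule; the helper name and signature are illustrative, not part of ONNX Runtime:

```python
def choose_kernel(total_sequence_length: int, sequence_length: int, disable_v1_kernel: bool) -> str:
    """Illustrative model of the v1/v2 choice in SparseAttention<T>::ComputeInternal."""
    past_seq_len = total_sequence_length - sequence_length
    is_prompt = past_seq_len == 0
    # v1 handles only the prompt with right-side padding; v2 also handles
    # token generation and left-side padding, so it is the fallback.
    use_v2_kernel = disable_v1_kernel or not is_prompt
    return "v2" if use_v2_kernel else "v1"


assert choose_kernel(total_sequence_length=128, sequence_length=128, disable_v1_kernel=False) == "v1"  # prompt
assert choose_kernel(total_sequence_length=129, sequence_length=1, disable_v1_kernel=False) == "v2"    # decoding step
```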
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h
@@ -25,7 +25,7 @@ class SparseAttention final : public CudaKernel {
int sparse_block_size_; // block size for sparsity
bool do_rotary_; // Has rotary positional embedding
bool rotary_interleaved_; // Interleaved rotary positional embedding
bool disable_v1_kernel_; // Disable V2 kernel
bool disable_v1_kernel_; // Whether to disable the v1 kernel and use the v2 kernel for prompt.
mutable bool kernel_loaded_; // Kernel has been loaded
};

3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu
@@ -252,9 +252,11 @@ Status QkvToContext(
data.kernel_layout.num_layout);
#endif

int sm = device_prop.major * 10 + device_prop.minor;
if (data.use_v2_kernel) {
sparse_attention_v2::SparseAttentionParams params(
ort_stream,
sm,
data.output,
reinterpret_cast<const void*>(query),
reinterpret_cast<const void*>(data.present_key),
@@ -289,6 +291,7 @@ Status QkvToContext(
} else {
sparse_attention_v1::SparseAttentionParams params(
ort_stream,
sm,
data.output,
reinterpret_cast<const void*>(query),
reinterpret_cast<const void*>(data.present_key),
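Both the v1 and v2 parameter structs now receive the same packed compute-capability value, computed as major * 10 + minor. A small sketch of that encoding (the helper name is made up; the example GPUs use their standard CUDA compute capabilities):

```python
def pack_sm(major: int, minor: int) -> int:
    # Mirrors `int sm = device_prop.major * 10 + device_prop.minor;` in the C++ above.
    return major * 10 + minor


assert pack_sm(7, 5) == 75  # T4 (Turing), the target of this PR
assert pack_sm(8, 0) == 80  # A100 (Ampere)
```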
compile_sparse_attention.py
@@ -4,21 +4,21 @@
# --------------------------------------------------------------------------

# Use triton AoT compiler to convert sparse_attention_triton.py to C source files including cubin and dispatcher.
# Example to use this script (Tested with CUDA 12.3 in Ubuntu 20.04):
# python3 -m pip install triton==2.3.0
# Example of using this script (tested with Python 3.10 and CUDA 12.3 on Ubuntu 20.04):
# python3 -m pip install torch==2.3.0 triton==2.3.0
# python3 compile_sparse_attention.py | sh
#
# Note that sparse_attention_v1_*.cc and sparse_attention_dispatcher_*.h are the generated files.

import math
from itertools import product

import torch

def generate_triton_compile_shell_script(dtype="fp16"):

def generate_triton_compile_shell_script(sm, dtype="fp16"):
assert dtype in ["fp16", "bf16"]
print("export TRITON_ROOT=$(pip show triton | grep Location | cut -d' ' -f2)")
print('export ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader|head -n 1)"')
print("export SM=$(echo $ARCH | sed -e 's/\\.//g')")

# Modify the compile.py to use custom template file template_h.txt and template_c.txt in current directory.
# Also pass block_m to the template.
@@ -51,8 +51,8 @@ def generate_triton_compile_shell_script(dtype="fp16"):
scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
prefix = "python compile.py sparse_attention_triton.py"
filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm${{SM}}"
name = f"sparse_attention_{dtype}_sm${{SM}}"
filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm{sm}"
name = f"sparse_attention_{dtype}_sm{sm}"
num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))
num_stages = 2
# TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
@@ -62,7 +62,7 @@ def generate_triton_compile_shell_script(dtype="fp16"):
)

# Generate the dispatcher.
dispatcher = f"sparse_attention_dispatcher_{dtype}_sm${{SM}}"
dispatcher = f"sparse_attention_dispatcher_{dtype}_sm{sm}"
print(f"cd {out_dir}")
print(f"python ${{TRITON_ROOT}}/triton/tools/link.py sparse_attention_v1_*.h -o {dispatcher}")
print("rm *.h")
@@ -122,5 +122,10 @@ def generate_triton_compile_shell_script(dtype="fp16"):


if __name__ == "__main__":
for dtype in ["fp16", "bf16"]:
generate_triton_compile_shell_script(dtype)
major, minor = torch.cuda.get_device_capability()
print(f"echo Generate sparse attention v1 kernels for compute capability:{major}.{minor}")
assert major >= 7, "triton only supports compute capability >= 7.0"

sm = major * 10 + minor
for dtype in ["fp16", "bf16"] if major >= 8 else ["fp16"]:
generate_triton_compile_shell_script(sm, dtype)
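Assuming the script is run on a Turing or an Ampere GPU, the updated __main__ block behaves roughly as sketched below: bf16 kernels are only generated for compute capability 8.x and above, and the sm value becomes the suffix of the generated files. The helper function is illustrative, not part of the script:

```python
def kernel_build_plan(major: int, minor: int):
    """Which dtypes get AoT-compiled kernels, and which sm suffix the generated files carry."""
    assert major >= 7, "triton only supports compute capability >= 7.0"
    sm = major * 10 + minor
    dtypes = ["fp16", "bf16"] if major >= 8 else ["fp16"]  # bf16 needs Ampere or newer
    return sm, dtypes


assert kernel_build_plan(7, 5) == (75, ["fp16"])          # T4: fp16 only, *_sm75 files
assert kernel_build_plan(8, 0) == (80, ["fp16", "bf16"])  # A100: both dtypes, *_sm80 files
```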
@@ -15,6 +15,8 @@ namespace sparse_attention_v1 {

struct SparseAttentionParams {
onnxruntime::Stream* ort_stream;
int sm;  // compute capability, e.g. 80 for A100

void* output;
const void* q;
const void* k;
@@ -57,6 +59,7 @@ struct SparseAttentionParams {

SparseAttentionParams(
onnxruntime::Stream* ort_stream,
int sm,
void* output,
const void* q,
const void* k,
@@ -76,6 +79,7 @@
int layout_col_stride_h,
int num_layout) {
this->ort_stream = ort_stream;
this->sm = sm;
this->output = output;
this->q = q;
this->k = k;