[CUDA] Add SparseAttention kernel for sm=75 (#20531)
### Description
Follow up of #20216 to add a kernel for sm=75 (GPUs like T4, GeForce RTX
2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000, etc.)

- [x] Add kernel for sm=75.
- [x] Update dispatch code to select the kernel based on sm.
- [x] Update compile script to use num_stages=2 instead of 3 for sm=75 (see the sketch below).
- [x] Refactor test script and add tests for bfloat16.
- [x] Fix performance test of token generation (previously we did not concatenate past_key).
- [x] Fix debug build.
- [x] Run performance test and update numbers.

For sm=70, the v1 kernel compiles, but compiling the v2 kernel fails, so
sm=70 is skipped in this pull request.
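
Below is a minimal sketch of how the target SM value and the num_stages choice mentioned above can be derived. It assumes PyTorch is available; the helper `pick_num_stages` is hypothetical and not part of this PR:

```python
import torch

def pick_num_stages(sm: int) -> int:
    # Per this PR, sm=75 uses num_stages=2 instead of the default 3.
    return 2 if sm == 75 else 3

# Compute capability of the current GPU, e.g. (7, 5) for T4, (8, 0) for A100.
major, minor = torch.cuda.get_device_capability()
sm = major * 10 + minor  # 75, 80, ...
print(f"sm={sm}, num_stages={pick_num_stages(sm)}")
```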

Performance test on a T4 GPU (using a Standard_NC4as_T4_v3 Azure VM) with
`batch_size=4, num_heads=32, max_seq_len=8192, head_size=128,
sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`.

We compare sparse attention to the corresponding GQA with dense causal
attention. Note that dense GQA needs more computation since no sparsity is
used. TORCH-GQA uses a naive implementation (using cuSPARSE Block-SpMM
could be faster).
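
For intuition about the sparsity configuration above, here is a rough sketch of a block-level causal mask with local blocks and a vertical stride. It is an approximation only; the actual kernels use per-head layouts (num_layout=8) whose stride offsets may differ:

```python
import torch

def block_sparse_mask(num_blocks: int, local_blocks: int = 16, vert_stride: int = 8) -> torch.Tensor:
    # Keep block (q, k) when it is causal and either inside the local window
    # or on a vertical stride column.
    q = torch.arange(num_blocks).unsqueeze(1)  # query block index
    k = torch.arange(num_blocks).unsqueeze(0)  # key block index
    causal = k <= q
    local = (q - k) < local_blocks
    vertical = (k + 1) % vert_stride == 0
    return causal & (local | vertical)

# max_seq_len=8192 with sparse_block_size=64 gives a 128 x 128 block mask.
mask = block_sparse_mask(8192 // 64)
print(mask.shape, mask.float().mean().item())  # density is well below 1.0
```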

```
prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16:
   sequence_length   TORCH-GQA  ORT-GQA-Dense  ORT-SparseAtt
1             32.0    0.184173       2.994347       0.089064
2             64.0    0.303300       3.023986       0.107418
3            128.0    0.887795       3.073728       0.174213
4            256.0    2.797654       3.246899       0.357869
5            512.0   10.055048       3.814039       0.893903
6           1024.0   37.849937       5.818439       2.658720
7           2048.0  148.641785      13.638480       7.202690
8           4096.0    OOM           43.556847      17.680954
9           8192.0    OOM           161.628540      44.336670

token-sm75-batch4-head32-d128-local16-vert8-torch.float16:
   past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-SparseAtt
1                  32.0   0.110353       2.996305       0.137509
2                  64.0   0.145088       3.006860       0.165424
3                 128.0   0.219500       3.036448       0.192001
4                 256.0   0.347496       3.071341       0.249125
5                 512.0   0.595842       3.135225       0.398726
6                1024.0   1.081216       3.261110       0.612744
7                2048.0   2.060307       3.515578       0.685670
8                4096.0   OOM            4.022986       0.819707
9                8191.0   OOM            5.024528       1.072912
```
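
At the longest lengths, ORT-SparseAtt is roughly 3.6x faster than ORT-GQA-Dense for the 8192 prompt case and roughly 4.7x faster for token generation with past length 8191, while TORCH-GQA runs out of memory.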

### Motivation and Context

To run inference of Phi-3-small on T4 GPUs.
tianleiwu authored and yihonglyu committed May 4, 2024
1 parent 90a2c26 commit 41bf45d
Showing 24 changed files with 1,374 additions and 300 deletions.
29 changes: 20 additions & 9 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
@@ -63,6 +63,15 @@ SparseAttention<T>::SparseAttention(const OpKernelInfo& info)

template <typename T>
Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto& device_prop = GetDeviceProp();
if constexpr (std::is_same<T, BFloat16>::value) {
if (device_prop.major < 8) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"bfloat16 requires Ampere and above GPUs with Compute Capability >= 8. Got major=",
device_prop.major);
}
}

const Tensor* query = context->Input<Tensor>(0);
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
@@ -74,8 +83,6 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* cos_cache = context->Input<Tensor>(8);
const Tensor* sin_cache = context->Input<Tensor>(9);

auto& device_prop = GetDeviceProp();

SparseAttentionParameters parameters;

// Parameters from node attribute
@@ -97,9 +104,9 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
block_mask,
seqlens_k_total,
total_seq_len));

// Some limitations of CUDA kernels
if (!sparse_attention_v1::is_supported_sparse_attention(device_prop)) {
// The v1 and v2 kernels have the same coverage, so only check one of them to see whether it is supported.
if (!sparse_attention_v1::is_supported_device(device_prop)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SparseAttention only support CUDA device with compute capacity 8.*. Got ",
device_prop.major);
@@ -118,6 +125,9 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {

int past_seq_len = parameters.total_sequence_length - parameters.sequence_length;
bool is_prompt = (past_seq_len == 0);

// The v1 kernel supports prompt only, with right padding only.
// The v2 kernel supports both prompt and token generation, with left or right padding.
bool use_v2_kernel = disable_v1_kernel_ || !is_prompt;

// Async Copy total_k_seq_len from GPU to CPU.
@@ -139,19 +149,21 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
}

if (!kernel_loaded_) {
int sm = device_prop.major * 10 + device_prop.minor;
if constexpr (std::is_same<T, MLFloat16>::value) {
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
// TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_fp16();
sparse_attention_v2::load_sparse_attention_fp16(sm);
} else {
sparse_attention_v1::load_sparse_attention_fp16();
sparse_attention_v1::load_sparse_attention_fp16(sm);
}
} else {
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_bf16();
sparse_attention_v2::load_sparse_attention_bf16(sm);
} else {
sparse_attention_v1::load_sparse_attention_bf16();
sparse_attention_v1::load_sparse_attention_bf16(sm);
}
}
kernel_loaded_ = true;
@@ -164,7 +176,6 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);

assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
std::vector<int64_t> present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size};
TensorShape present_shape(present_dims);
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h
@@ -25,7 +25,7 @@ class SparseAttention final : public CudaKernel {
int sparse_block_size_; // block size for sparsity
bool do_rotary_; // Has rotary positional embedding
bool rotary_interleaved_; // Interleaved rotary positional embedding
bool disable_v1_kernel_; // Disable V2 kernel
bool disable_v1_kernel_; // Whether to disable the v1 kernel and use the v2 kernel for prompt.
mutable bool kernel_loaded_; // Kernel has been loaded
};

3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu
@@ -252,9 +252,11 @@ Status QkvToContext(
data.kernel_layout.num_layout);
#endif

int sm = device_prop.major * 10 + device_prop.minor;
if (data.use_v2_kernel) {
sparse_attention_v2::SparseAttentionParams params(
ort_stream,
sm,
data.output,
reinterpret_cast<const void*>(query),
reinterpret_cast<const void*>(data.present_key),
@@ -289,6 +291,7 @@
} else {
sparse_attention_v1::SparseAttentionParams params(
ort_stream,
sm,
data.output,
reinterpret_cast<const void*>(query),
reinterpret_cast<const void*>(data.present_key),
@@ -4,21 +4,21 @@
# --------------------------------------------------------------------------

# Use triton AoT compiler to convert sparse_attention_triton.py to C source files including cubin and dispatcher.
# Example to use this script (Tested with CUDA 12.3 in Ubuntu 20.04):
# python3 -m pip install triton==2.3.0
# Example to use this script (Tested with Python 3.10 and CUDA 12.3 in Ubuntu 20.04):
# python3 -m pip install torch==2.3.0 triton==2.3.0
# python3 compile_sparse_attention.py | sh
#
# Note that sparse_attention_v1_*.cc and sparse_attention_dispatcher_*.h are the generated files.

import math
from itertools import product

import torch

def generate_triton_compile_shell_script(dtype="fp16"):

def generate_triton_compile_shell_script(sm, dtype="fp16"):
assert dtype in ["fp16", "bf16"]
print("export TRITON_ROOT=$(pip show triton | grep Location | cut -d' ' -f2)")
print('export ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader|head -n 1)"')
print("export SM=$(echo $ARCH | sed -e 's/\\.//g')")

# Modify the compile.py to use custom template file template_h.txt and template_c.txt in current directory.
# Also pass block_m to the template.
@@ -51,8 +51,8 @@ def generate_triton_compile_shell_script(dtype="fp16"):
scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
prefix = "python compile.py sparse_attention_triton.py"
filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm${{SM}}"
name = f"sparse_attention_{dtype}_sm${{SM}}"
filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm{sm}"
name = f"sparse_attention_{dtype}_sm{sm}"
num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))
num_stages = 2
# TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
@@ -62,7 +62,7 @@ def generate_triton_compile_shell_script(dtype="fp16"):
)

# Generate the dispatcher.
dispatcher = f"sparse_attention_dispatcher_{dtype}_sm${{SM}}"
dispatcher = f"sparse_attention_dispatcher_{dtype}_sm{sm}"
print(f"cd {out_dir}")
print(f"python ${{TRITON_ROOT}}/triton/tools/link.py sparse_attention_v1_*.h -o {dispatcher}")
print("rm *.h")
@@ -122,5 +122,10 @@ def generate_triton_compile_shell_script(dtype="fp16"):


if __name__ == "__main__":
    for dtype in ["fp16", "bf16"]:
        generate_triton_compile_shell_script(dtype)
    major, minor = torch.cuda.get_device_capability()
    print(f"echo Generate sparse attention v1 kernels for compute capability:{major}.{minor}")
    assert major >= 7, "triton only supports compute capability >= 7.0"

    sm = major * 10 + minor
    for dtype in ["fp16", "bf16"] if major >= 8 else ["fp16"]:
        generate_triton_compile_shell_script(sm, dtype)
@@ -15,6 +15,8 @@ namespace sparse_attention_v1 {

struct SparseAttentionParams {
onnxruntime::Stream* ort_stream;
int sm; // compute capability like 80 for A100

void* output;
const void* q;
const void* k;
@@ -57,6 +59,7 @@ struct SparseAttentionParams {

SparseAttentionParams(
onnxruntime::Stream* ort_stream,
int sm,
void* output,
const void* q,
const void* k,
@@ -76,6 +79,7 @@
int layout_col_stride_h,
int num_layout) {
this->ort_stream = ort_stream;
this->sm = sm;
this->output = output;
this->q = q;
this->k = k;
