
[CUDA] Add SparseAttention kernel for sm=75 #20531

Merged: 5 commits merged into main from tlwu/sparse_attention_kernel_sm7x on May 2, 2024

Conversation

tianleiwu (Contributor) commented May 1, 2024

Description

Follow-up of #20216 to add a kernel for sm=75 (Turing GPUs such as the T4, GeForce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000, etc.)

  • Add a kernel for sm=75.
  • Update the dispatch code to select the kernel based on sm.
  • Update the compile script to use num_stages=2 instead of 3 for sm=75.
  • Refactor the test script and add tests for bfloat16.
  • Fix the performance test of token generation (previously past_key was not concatenated; a sketch follows this list).
  • Fix the debug build.
  • Run the performance test and update the numbers.
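
To illustrate the token-generation fix above, here is a minimal sketch of concatenating the new token's key/value onto the past cache before the next step is measured. The helper name and tensor layout are assumptions, not the actual benchmark code:

```python
import torch

# Minimal sketch: assumed (batch, num_kv_heads, seq_len, head_size) layout;
# the real performance-test code may differ.
def append_kv_cache(past_key, past_value, new_key, new_value):
    """Grow the KV cache by the newly generated token along the sequence axis."""
    past_key = torch.cat([past_key, new_key], dim=2)
    past_value = torch.cat([past_value, new_value], dim=2)
    return past_key, past_value
```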

For sm=70, the v1 kernel can be compiled, but compiling the v2 kernel fails, so sm=70 is skipped in this pull request.
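
A minimal sketch of the resulting selection logic, using hypothetical names (the actual dispatch is split between the CUDA provider and the Triton compile script):

```python
# Minimal sketch with hypothetical names; not the actual ORT dispatch code.
def select_sparse_attention_config(sm: int):
    """Return (kernel_version, num_stages) for a given SM version."""
    if sm >= 80:
        return "v2", 3   # Ampere and newer keep the deeper pipeline
    if sm == 75:
        return "v2", 2   # Turing: num_stages=2 instead of 3
    # sm=70: the v1 kernel compiles, but it is not wired up in this PR
    raise ValueError(f"SparseAttention requires sm>=75, got sm={sm}")
```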

Performance test on a T4 GPU (Standard_NC4as_T4_v3 Azure VM) with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`:
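
For intuition about the `sparse_block_size`, `local_blocks`, and `vert_stride` settings above, here is a minimal sketch of a per-head block mask. It assumes the common "local + vertical stride" block-sparse causal pattern; the exact ORT layout (which also varies across the `num_layout` head groups) may differ:

```python
import torch

# Minimal sketch of a "local + vertical stride" block-sparse causal layout;
# an assumed pattern, not the exact ORT implementation.
def block_sparse_mask(num_blocks, local_blocks=16, vert_stride=8):
    """Boolean (num_blocks, num_blocks) mask of key blocks each query block attends to."""
    q = torch.arange(num_blocks).unsqueeze(1)  # query block index
    k = torch.arange(num_blocks).unsqueeze(0)  # key block index
    causal = k <= q
    local = (q - k) < local_blocks             # the most recent blocks
    strided = (k + 1) % vert_stride == 0       # every vert_stride-th block
    return causal & (local | strided)
```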

We compare sparse attention to the corresponding GQA with a dense causal mask. Note that dense GQA needs more computation since no sparsity is used. TORCH-GQA uses a naive implementation (using cuSPARSE Block-SpMM could be faster).

```
prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16:
   sequence_length   TORCH-GQA  ORT-GQA-Dense  ORT-SparseAtt
1             32.0    0.184173       2.994347       0.089064
2             64.0    0.303300       3.023986       0.107418
3            128.0    0.887795       3.073728       0.174213
4            256.0    2.797654       3.246899       0.357869
5            512.0   10.055048       3.814039       0.893903
6           1024.0   37.849937       5.818439       2.658720
7           2048.0  148.641785      13.638480       7.202690
8           4096.0         OOM      43.556847      17.680954
9           8192.0         OOM     161.628540      44.336670

token-sm75-batch4-head32-d128-local16-vert8-torch.float16:
   past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-SparseAtt
1                  32.0   0.110353       2.996305       0.137509
2                  64.0   0.145088       3.006860       0.165424
3                 128.0   0.219500       3.036448       0.192001
4                 256.0   0.347496       3.071341       0.249125
5                 512.0   0.595842       3.135225       0.398726
6                1024.0   1.081216       3.261110       0.612744
7                2048.0   2.060307       3.515578       0.685670
8                4096.0        OOM       4.022986       0.819707
9                8191.0        OOM       5.024528       1.072912
```

Motivation and Context

To run inference of Phi-3-small on a T4 GPU.

@tianleiwu tianleiwu marked this pull request as draft May 1, 2024 17:10
@tianleiwu tianleiwu marked this pull request as ready for review May 1, 2024 21:52
@tianleiwu tianleiwu merged commit 8707655 into main May 2, 2024
94 of 95 checks passed
@tianleiwu tianleiwu deleted the tlwu/sparse_attention_kernel_sm7x branch May 2, 2024 02:52
@sophies927 sophies927 added the triage:approved Approved for cherrypicks for release label May 2, 2024
@yihonglyu yihonglyu added the cherry-picked Cherry-picked for a cherrypicks branch label May 4, 2024
yihonglyu pushed a commit that referenced this pull request May 4, 2024
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request May 7, 2024
@yihonglyu yihonglyu added the rel-merged Cherrypicks merged into release label May 8, 2024
Labels: cherry-picked, rel-merged, release:1.18.0, triage:approved