Add SDPA benchmark for Llama 2 13B, Llama 2 70B, Llama 3 8B, Llama 3 70B #276

Merged
merged 9 commits into main from more-sdpa-benchmarks
Apr 25, 2024

IvanYashchuk (Collaborator) commented Apr 25, 2024

This PR adds a parametrized microbenchmark that measures the performance of the scaled dot product attention (SDPA) call for several model configurations:

pytest thunder/benchmarks/targets.py -k "test_litgpt_sdpa_grad" --collect-only
<Module thunder/benchmarks/targets.py>
  <Function test_litgpt_sdpa_grad[Llama-2-7b-hf-torch]>
  <Function test_litgpt_sdpa_grad[Llama-2-7b-hf-torch.compile]>
  <Function test_litgpt_sdpa_grad[Llama-2-7b-hf-thunder]>
  <Function test_litgpt_sdpa_grad[Llama-2-7b-hf-thunder+cudnn]>
  <Function test_litgpt_sdpa_grad[Llama-2-13b-hf-torch]>
  <Function test_litgpt_sdpa_grad[Llama-2-13b-hf-torch.compile]>
  <Function test_litgpt_sdpa_grad[Llama-2-13b-hf-thunder]>
  <Function test_litgpt_sdpa_grad[Llama-2-13b-hf-thunder+cudnn]>
  <Function test_litgpt_sdpa_grad[Llama-2-70b-hf-torch]>
  <Function test_litgpt_sdpa_grad[Llama-2-70b-hf-torch.compile]>
  <Function test_litgpt_sdpa_grad[Llama-2-70b-hf-thunder]>
  <Function test_litgpt_sdpa_grad[Llama-2-70b-hf-thunder+cudnn]>
  <Function test_litgpt_sdpa_grad[Llama-3-8B-torch]>
  <Function test_litgpt_sdpa_grad[Llama-3-8B-torch.compile]>
  <Function test_litgpt_sdpa_grad[Llama-3-8B-thunder]>
  <Function test_litgpt_sdpa_grad[Llama-3-8B-thunder+cudnn]>
  <Function test_litgpt_sdpa_grad[Llama-3-70B-torch]>
  <Function test_litgpt_sdpa_grad[Llama-3-70B-torch.compile]>
  <Function test_litgpt_sdpa_grad[Llama-3-70B-thunder]>
  <Function test_litgpt_sdpa_grad[Llama-3-70B-thunder+cudnn]>
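
The 20 collected IDs above are the cross product of five model configs and four executors. As a rough illustration of how such IDs arise, here is a minimal standard-library sketch. The names `CONFIGS`, `EXECUTORS`, and `benchmark_ids` are hypothetical and are not taken from thunder/benchmarks/targets.py; the actual parametrization in that file may be written differently.

```python
from itertools import product

# Config and executor names copied from the collected test IDs above.
CONFIGS = ["Llama-2-7b-hf", "Llama-2-13b-hf", "Llama-2-70b-hf", "Llama-3-8B", "Llama-3-70B"]
EXECUTORS = ["torch", "torch.compile", "thunder", "thunder+cudnn"]


def benchmark_ids(configs, executors):
    """Return one 'config-executor' ID per combination, in config-major order."""
    return [f"{config}-{executor}" for config, executor in product(configs, executors)]


ids = benchmark_ids(CONFIGS, EXECUTORS)
print(len(ids))  # one ID per (config, executor) pair
for test_id in ids:
    print(f"test_litgpt_sdpa_grad[{test_id}]")
```

Adding a config or an executor to either list grows the collected set multiplicatively, which is why selecting a subset with `-k` is handy when running locally.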

H100 results:

------------------------------------------------- benchmark 'config=Llama-2-13b-hf bs=1': 4 tests --------------------------------------------------
Name (time in ms)                                              Min               Max              Mean            StdDev            Median          
----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs1-torch]             3.0661 (1.0)      3.1379 (1.0)      3.1114 (1.0)      0.0178 (1.0)      3.1134 (1.0)    
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs1-torch.compile]     3.1149 (1.02)     3.2705 (1.04)     3.2171 (1.03)     0.0337 (1.89)     3.2256 (1.04)   
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs1-thunder+cudnn]     3.1748 (1.04)     3.5941 (1.15)     3.3346 (1.07)     0.0904 (5.07)     3.3597 (1.08)   
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs1-thunder]           3.2333 (1.05)     6.5686 (2.09)     3.3775 (1.09)     0.5191 (29.13)    3.2828 (1.05)   
----------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------- benchmark 'config=Llama-2-13b-hf bs=2': 4 tests --------------------------------------------------
Name (time in ms)                                              Min                Max              Mean            StdDev            Median          
-----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs2-torch]             5.7707 (1.0)      13.1327 (2.16)     6.0389 (1.01)     1.1518 (16.65)    5.8793 (1.0)    
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs2-thunder+cudnn]     5.8385 (1.01)      6.2097 (1.02)     6.0195 (1.00)     0.1196 (1.73)     6.0925 (1.04)   
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs2-thunder]           5.8390 (1.01)      6.1858 (1.02)     6.0104 (1.00)     0.1172 (1.69)     6.0403 (1.03)   
test_litgpt_sdpa_grad[Llama-2-13b-hf-bs2-torch.compile]     5.8466 (1.01)      6.0717 (1.0)      6.0024 (1.0)      0.0692 (1.0)      6.0319 (1.03)   
-----------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------- benchmark 'config=Llama-2-70b-hf bs=1': 4 tests --------------------------------------------------
Name (time in ms)                                              Min               Max              Mean            StdDev            Median          
----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs1-torch]             4.6588 (1.0)      4.7111 (1.0)      4.6828 (1.0)      0.0122 (1.0)      4.6838 (1.0)    
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs1-thunder]           4.7134 (1.01)     5.0908 (1.08)     4.9072 (1.05)     0.1150 (9.39)     4.9086 (1.05)   
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs1-torch.compile]     4.7558 (1.02)     4.9188 (1.04)     4.8806 (1.04)     0.0251 (2.05)     4.8840 (1.04)   
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs1-thunder+cudnn]     4.7860 (1.03)     6.6831 (1.42)     4.9499 (1.06)     0.2947 (24.08)    4.9545 (1.06)   
----------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------- benchmark 'config=Llama-2-70b-hf bs=2': 4 tests --------------------------------------------------
Name (time in ms)                                              Min                Max              Mean            StdDev            Median          
-----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs2-torch]             8.9816 (1.0)       9.0927 (1.0)      9.0327 (1.0)      0.0235 (1.0)      9.0262 (1.0)    
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs2-torch.compile]     9.0140 (1.00)      9.5012 (1.04)     9.2556 (1.02)     0.1917 (8.14)     9.2814 (1.03)   
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs2-thunder+cudnn]     9.0479 (1.01)     28.5899 (3.14)     9.7804 (1.08)     3.0640 (130.17)   9.1992 (1.02)   
test_litgpt_sdpa_grad[Llama-2-70b-hf-bs2-thunder]           9.0577 (1.01)      9.7343 (1.07)     9.2643 (1.03)     0.2303 (9.78)     9.1387 (1.01)   
-----------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------- benchmark 'config=Llama-2-7b-hf bs=1': 4 tests --------------------------------------------------
Name (time in ms)                                             Min                Max              Mean            StdDev            Median          
----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs1-torch]             2.5610 (1.0)       2.8007 (1.0)      2.6968 (1.0)      0.0616 (2.22)     2.7176 (1.01)   
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs1-thunder+cudnn]     2.6426 (1.03)     15.2240 (5.44)     3.1441 (1.17)     2.0792 (74.79)    2.7321 (1.01)   
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs1-torch.compile]     2.6584 (1.04)      2.8343 (1.01)     2.7140 (1.01)     0.0426 (1.53)     2.7039 (1.0)    
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs1-thunder]           2.6909 (1.05)      2.8335 (1.01)     2.7428 (1.02)     0.0278 (1.0)      2.7421 (1.01)   
----------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------- benchmark 'config=Llama-2-7b-hf bs=2': 4 tests -------------------------------------------------
Name (time in ms)                                             Min               Max              Mean            StdDev            Median          
---------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs2-torch]             4.7210 (1.0)      5.3637 (1.08)     4.8691 (1.00)     0.0848 (1.0)      4.8589 (1.01)   
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs2-torch.compile]     4.7275 (1.00)     4.9739 (1.0)      4.8538 (1.0)      0.0913 (1.08)     4.8042 (1.0)    
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs2-thunder]           4.7284 (1.00)     5.1364 (1.03)     4.9092 (1.01)     0.1227 (1.45)     4.8489 (1.01)   
test_litgpt_sdpa_grad[Llama-2-7b-hf-bs2-thunder+cudnn]     4.7709 (1.01)     5.0464 (1.01)     4.9007 (1.01)     0.0870 (1.03)     4.9477 (1.03)   
---------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------- benchmark 'config=Llama-3-70B bs=1': 4 tests ---------------------------------------------------
Name (time in ms)                                            Min                Max               Mean            StdDev             Median          
-----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-3-70B-bs1-torch]             16.4030 (1.0)      16.5437 (1.0)      16.4807 (1.0)      0.0365 (1.0)      16.4812 (1.0)    
test_litgpt_sdpa_grad[Llama-3-70B-bs1-torch.compile]     16.4350 (1.00)     17.2223 (1.04)     16.6717 (1.01)     0.1999 (5.48)     16.6009 (1.01)   
test_litgpt_sdpa_grad[Llama-3-70B-bs1-thunder+cudnn]     16.4777 (1.00)     17.5759 (1.06)     16.7712 (1.02)     0.3402 (9.33)     16.6264 (1.01)   
test_litgpt_sdpa_grad[Llama-3-70B-bs1-thunder]           16.4842 (1.00)     17.6872 (1.07)     16.7536 (1.02)     0.3438 (9.43)     16.5939 (1.01)   
-----------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------- benchmark 'config=Llama-3-70B bs=2': 4 tests ---------------------------------------------------
Name (time in ms)                                            Min                Max               Mean            StdDev             Median          
-----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-3-70B-bs2-torch]             32.2738 (1.0)      32.4355 (1.0)      32.3522 (1.0)      0.0370 (1.0)      32.3547 (1.0)    
test_litgpt_sdpa_grad[Llama-3-70B-bs2-thunder]           32.3163 (1.00)     35.1603 (1.08)     32.6140 (1.01)     0.5682 (15.36)    32.4356 (1.00)   
test_litgpt_sdpa_grad[Llama-3-70B-bs2-torch.compile]     32.3179 (1.00)     33.8013 (1.04)     32.5951 (1.01)     0.4400 (11.89)    32.4149 (1.00)   
test_litgpt_sdpa_grad[Llama-3-70B-bs2-thunder+cudnn]     32.3292 (1.00)     34.3968 (1.06)     32.6055 (1.01)     0.5089 (13.75)    32.4181 (1.00)   
-----------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------- benchmark 'config=Llama-3-8B bs=1': 4 tests --------------------------------------------------
Name (time in ms)                                          Min                Max              Mean            StdDev            Median          
-------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-3-8B-bs1-torch]             8.5236 (1.0)       8.7133 (1.0)      8.6108 (1.0)      0.0474 (1.0)      8.6179 (1.0)    
test_litgpt_sdpa_grad[Llama-3-8B-bs1-thunder+cudnn]     8.5499 (1.00)     10.9637 (1.26)     8.8771 (1.03)     0.4501 (9.50)     8.7242 (1.01)   
test_litgpt_sdpa_grad[Llama-3-8B-bs1-thunder]           8.5540 (1.00)      9.3390 (1.07)     8.8146 (1.02)     0.2570 (5.42)     8.7282 (1.01)   
test_litgpt_sdpa_grad[Llama-3-8B-bs1-torch.compile]     8.6270 (1.01)      9.1252 (1.05)     8.8826 (1.03)     0.1784 (3.76)     8.9499 (1.04)   
-------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------- benchmark 'config=Llama-3-8B bs=2': 4 tests ----------------------------------------------------
Name (time in ms)                                           Min                Max               Mean            StdDev             Median          
----------------------------------------------------------------------------------------------------------------------------------------------------
test_litgpt_sdpa_grad[Llama-3-8B-bs2-torch]             16.4562 (1.0)      16.5750 (1.0)      16.5125 (1.0)      0.0295 (1.0)      16.5139 (1.0)    
test_litgpt_sdpa_grad[Llama-3-8B-bs2-thunder+cudnn]     16.4659 (1.00)     17.7330 (1.07)     16.7694 (1.02)     0.3560 (12.06)    16.6258 (1.01)   
test_litgpt_sdpa_grad[Llama-3-8B-bs2-torch.compile]     16.4691 (1.00)     17.1162 (1.03)     16.7076 (1.01)     0.2045 (6.93)     16.6341 (1.01)   
test_litgpt_sdpa_grad[Llama-3-8B-bs2-thunder]           16.4893 (1.00)     23.6173 (1.42)     16.9235 (1.02)     1.1322 (38.37)    16.6114 (1.01)   
----------------------------------------------------------------------------------------------------------------------------------------------------
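
For reading the tables above: the number in parentheses after each value is that value's ratio to the best (smallest) entry in the same column, so `(1.0)` marks the column winner. The small sketch below recomputes the Min column ratios for the 'config=Llama-2-13b-hf bs=1' table. The helper name `relative_to_best` is hypothetical. Because the tables print rounded values, ratios recomputed this way can differ slightly from the printed ones in other columns (StdDev in particular).

```python
def relative_to_best(values):
    """Divide each value by the column minimum and round to two decimals."""
    best = min(values)
    return [round(value / best, 2) for value in values]


# Min column (ms) of the 'config=Llama-2-13b-hf bs=1' table, in row order:
# torch, torch.compile, thunder+cudnn, thunder.
min_ms = [3.0661, 3.1149, 3.1748, 3.2333]

ratios = relative_to_best(min_ms)
print(ratios)  # [1.0, 1.02, 1.04, 1.05], matching the parenthesized values
```

Read this way, the executors are within a few percent of each other on Min and Median in every table, while the larger Max and StdDev ratios come from individual slow iterations.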

I have also updated the litgpt pin to include the Llama 3 config and added that config to the general jit tests.

cc @crcrpar

crcrpar mentioned this pull request Apr 25, 2024

carmocca (Contributor) left a comment

LGTM! Thanks

carmocca merged commit de612bd into main Apr 25, 2024
42 checks passed
carmocca deleted the more-sdpa-benchmarks branch April 25, 2024 14:05