Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cat for nvfuser >= 0.1.7 #35

Merged
merged 9 commits into from
Mar 24, 2024
Merged

Enable cat for nvfuser >= 0.1.7 #35

merged 9 commits into from
Mar 24, 2024

Conversation

jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Mar 21, 2024

Benchmarks are neutral.

Before:

--------------------------------------------------------------------------------------------------------------------- benchmark: 24 tests ----------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                 Min                       Max                      Mean                 StdDev                    Median                    IQR            Outliers          OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_nanogpt_sdpa_fwd[thunder]                                85.9729 (1.0)            123.5039 (1.0)             90.1970 (1.0)           9.0070 (1.45)            87.2805 (1.0)           1.0572 (1.0)           2;3  11,086.8437 (1.0)          20           1
test_nanogpt_layer_norm_fwd[thunder]                          98.7879 (1.15)           236.1580 (1.91)           108.8600 (1.21)         30.0711 (4.83)           101.3121 (1.16)          3.2765 (3.10)          1;2   9,186.1104 (0.83)         20           1
test_nanogpt_gelu_fwd[thunder]                               110.0181 (1.28)           217.0510 (1.76)           115.9017 (1.28)         16.7485 (2.69)           112.1076 (1.28)          3.1011 (2.93)          1;3   8,628.0009 (0.78)         40           1
test_nanogpt_gelu_grad[thunder]                              293.2760 (3.41)           706.9011 (5.72)           324.0489 (3.59)         69.9223 (11.23)          299.2521 (3.43)         15.2696 (14.44)         2;9   3,085.9543 (0.28)         40           1
test_nanogpt_csa_fwd[thunder]                                312.5530 (3.64)           433.2711 (3.51)           320.6752 (3.56)         26.5578 (4.26)           314.2810 (3.60)          2.4339 (2.30)          1;1   3,118.4206 (0.28)         20           1
test_llama2_7b_rmsnorm_grad[thunder]                         413.3140 (4.81)           725.0749 (5.87)           494.8684 (5.49)         73.1271 (11.74)          511.1340 (5.86)        127.2761 (120.39)       20;0   2,020.7393 (0.18)         40           1
test_nanogpt_sdpa_grad[thunder]                              465.0408 (5.41)           629.3731 (5.10)           550.4548 (6.10)         57.3874 (9.22)           584.9096 (6.70)         97.4401 (92.17)         8;0   1,816.6797 (0.16)         20           1
test_nanogpt_mlp_fwd[thunder]                                474.1389 (5.51)           635.3639 (5.14)           480.8633 (5.33)         25.2371 (4.05)           476.1979 (5.46)          1.9741 (1.87)          1;4   2,079.5930 (0.19)         40           1
test_nanogpt_cross_entropy_fwd[thunder]                      617.4808 (7.18)           705.7891 (5.71)           626.5095 (6.95)         18.7819 (3.02)           622.5256 (7.13)          3.8975 (3.69)          1;1   1,596.1449 (0.14)         20           1
test_nanogpt_block_fwd[thunder]                              788.3550 (9.17)         1,079.0159 (8.74)           811.1224 (8.99)         63.5831 (10.21)          795.4536 (9.11)         11.8630 (11.22)         1;2   1,232.8595 (0.11)         20           1
test_nanogpt_csa_grad[thunder]                             1,125.9848 (13.10)        1,563.9060 (12.66)        1,389.8655 (15.41)       159.5754 (25.63)        1,458.0046 (16.70)       286.0019 (270.54)        6;0     719.4941 (0.06)         20           1
test_nanogpt_mlp_grad[thunder]                             1,340.7920 (15.60)        1,653.8859 (13.39)        1,429.6559 (15.85)        87.8710 (14.11)        1,405.1094 (16.10)        48.2167 (45.61)         6;5     699.4690 (0.06)         40           1
test_nanogpt_cross_entropy_grad[thunder]                   1,869.2850 (21.74)        1,897.7679 (15.37)        1,875.7246 (20.80)         6.2271 (1.0)          1,874.9150 (21.48)         6.9386 (6.56)          2;1     533.1273 (0.05)         20           1
test_nanogpt_block_grad[thunder]                           2,490.2120 (28.97)        3,805.0050 (30.81)        2,805.9399 (31.11)       349.0571 (56.05)        2,594.0740 (29.72)       456.2305 (431.56)        3;1     356.3868 (0.03)         20           1
test_nanogpt_gpt2_fwd[thunder]                             6,656.4621 (77.43)        8,000.5289 (64.78)        6,937.3684 (76.91)       594.5807 (95.48)        6,672.7730 (76.45)       367.6593 (347.78)        1;1     144.1469 (0.01)          5           1
test_nanogpt_gpt2_grad[thunder]                           23,743.2919 (276.17)      27,975.8458 (226.52)      25,283.1272 (280.31)    1,732.2901 (278.18)      24,580.3180 (281.62)    2,475.7239 (>1000.0)       1;0      39.5521 (0.00)          5           1
test_nanogpt_gpt2xl_fwd[thunder]                          35,379.9469 (411.52)      38,557.2209 (312.19)      37,649.9868 (417.42)    1,292.4774 (207.56)      38,157.8330 (437.19)    1,098.8793 (>1000.0)       1;1      26.5604 (0.00)          5           1
test_llama2_7b_sdpa_grad[thunder]                         53,339.5000 (620.42)      54,730.9991 (443.15)      53,902.8342 (597.61)      402.7068 (64.67)       53,828.8241 (616.73)      643.8149 (609.00)       15;0      18.5519 (0.00)         40           1
test_nanogpt_gpt2xl_grad[thunder]                        125,010.8320 (>1000.0)    127,470.0339 (>1000.0)    125,858.3189 (>1000.0)     986.0469 (158.35)     125,817.0409 (>1000.0)   1,205.6009 (>1000.0)       1;0       7.9454 (0.00)          5           1
test_llama2_causal_self_attention_7b_grad[thunder]       183,797.7511 (>1000.0)    188,489.4371 (>1000.0)    186,973.9573 (>1000.0)     919.1014 (147.60)     186,991.5576 (>1000.0)   1,137.5169 (>1000.0)      10;1       5.3483 (0.00)         40           1
test_llama2_mlp_7b_grad[thunder]                         236,751.7750 (>1000.0)    242,391.4201 (>1000.0)    239,621.8158 (>1000.0)   1,236.3999 (198.55)     239,807.0445 (>1000.0)   1,620.5839 (>1000.0)      10;0       4.1732 (0.00)         40           1
test_open_llama_7b_fwd[thunder]                        1,134,035.7491 (>1000.0)  1,137,398.5051 (>1000.0)  1,135,597.6682 (>1000.0)   1,525.8574 (245.03)   1,135,803.1400 (>1000.0)   2,821.9949 (>1000.0)       3;0       0.8806 (0.00)          5           1
test_llama_2_7b_grad[thunder]                          1,771,630.3240 (>1000.0)  1,812,210.8860 (>1000.0)  1,794,236.1877 (>1000.0)  11,941.4007 (>1000.0)  1,795,424.9046 (>1000.0)  19,348.8440 (>1000.0)       8;0       0.5573 (0.00)         20           1
test_llama_2_7b_hf_fwd[thunder]                        2,362,912.0821 (>1000.0)  2,380,006.6570 (>1000.0)  2,372,338.9194 (>1000.0)   6,423.7285 (>1000.0)  2,373,629.3749 (>1000.0)   8,465.9495 (>1000.0)       2;0       0.4215 (0.00)          5           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After:

--------------------------------------------------------------------------------------------------------------------- benchmark: 24 tests ----------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                 Min                       Max                      Mean                 StdDev                    Median                    IQR            Outliers          OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_nanogpt_sdpa_fwd[thunder]                                85.2009 (1.0)            124.4859 (1.0)             88.2313 (1.0)           8.6566 (1.59)            85.8374 (1.0)           0.7616 (1.0)           1;3  11,333.8463 (1.0)          20           1
test_nanogpt_layer_norm_fwd[thunder]                          96.7828 (1.14)           226.2490 (1.82)           106.2040 (1.20)         28.4880 (5.23)            98.4970 (1.15)          5.3196 (6.98)          1;1   9,415.8393 (0.83)         20           1
test_nanogpt_gelu_fwd[thunder]                               109.7580 (1.29)           220.9090 (1.77)           115.8981 (1.31)         18.3704 (3.37)           111.4970 (1.30)          1.8084 (2.37)          2;4   8,628.2722 (0.76)         40           1
test_nanogpt_gelu_grad[thunder]                              281.1830 (3.30)           498.5251 (4.00)           327.3759 (3.71)         48.0104 (8.81)           343.1151 (4.00)         64.7875 (85.07)         2;2   3,054.5923 (0.27)         40           1
test_nanogpt_csa_fwd[thunder]                                310.6389 (3.65)           421.8400 (3.39)           318.4893 (3.61)         24.3820 (4.48)           312.7330 (3.64)          2.3684 (3.11)          1;1   3,139.8229 (0.28)         20           1
test_llama2_7b_rmsnorm_grad[thunder]                         321.2700 (3.77)           556.2950 (4.47)           427.4230 (4.84)         64.2681 (11.80)          406.9416 (4.74)         91.8243 (120.57)       17;0   2,339.6027 (0.21)         40           1
test_nanogpt_sdpa_grad[thunder]                              418.2830 (4.91)           596.5910 (4.79)           497.8039 (5.64)         41.0411 (7.53)           483.9625 (5.64)         21.3555 (28.04)         4;4   2,008.8230 (0.18)         20           1
test_nanogpt_mlp_fwd[thunder]                                475.6320 (5.58)           641.9369 (5.16)           481.3083 (5.46)         26.0898 (4.79)           476.8439 (5.56)          1.2020 (1.58)          1;6   2,077.6705 (0.18)         40           1
test_nanogpt_cross_entropy_fwd[thunder]                      618.6331 (7.26)           918.3811 (7.38)           636.9287 (7.22)         66.2812 (12.17)          621.8040 (7.24)          3.5827 (4.70)          1;1   1,570.0344 (0.14)         20           1
test_nanogpt_block_fwd[thunder]                              787.0132 (9.24)         1,084.0050 (8.71)           806.6165 (9.14)         65.5276 (12.03)          789.6619 (9.20)          7.1742 (9.42)          1;1   1,239.7466 (0.11)         20           1
test_nanogpt_csa_grad[thunder]                             1,164.3781 (13.67)        1,623.8189 (13.04)        1,377.5621 (15.61)       127.9834 (23.50)        1,330.9990 (15.51)       172.1865 (226.09)        8;0     725.9201 (0.06)         20           1
test_nanogpt_mlp_grad[thunder]                             1,330.7731 (15.62)        1,913.2278 (15.37)        1,416.6653 (16.06)       128.9129 (23.67)        1,366.1454 (15.92)        97.8445 (128.47)        6;4     705.8830 (0.06)         40           1
test_nanogpt_cross_entropy_grad[thunder]                   1,867.4510 (21.92)        1,888.5611 (15.17)        1,874.1238 (21.24)         5.4472 (1.0)          1,872.5011 (21.81)         6.8379 (8.98)          7;1     533.5827 (0.05)         20           1
test_nanogpt_block_grad[thunder]                           2,575.9540 (30.23)        4,624.9500 (37.15)        3,135.8434 (35.54)       394.9543 (72.51)        3,112.7925 (36.26)        27.5585 (36.19)         3;6     318.8935 (0.03)         20           1
test_nanogpt_gpt2_fwd[thunder]                             6,583.1831 (77.27)        7,109.2511 (57.11)        6,706.5494 (76.01)       226.2149 (41.53)        6,622.6891 (77.15)       167.7495 (220.26)        1;1     149.1080 (0.01)          5           1
test_nanogpt_gpt2_grad[thunder]                           24,254.6930 (284.68)      27,537.9689 (221.21)      25,215.8198 (285.79)    1,416.5809 (260.06)      24,382.2648 (284.05)    1,813.9615 (>1000.0)       1;0      39.6576 (0.00)          5           1
test_nanogpt_gpt2xl_fwd[thunder]                          36,717.4961 (430.95)      38,749.9100 (311.28)      38,145.8928 (432.34)      861.6864 (158.19)      38,645.1820 (450.21)    1,051.7663 (>1000.0)       1;0      26.2151 (0.00)          5           1
test_llama2_7b_sdpa_grad[thunder]                         53,971.0042 (633.46)      55,367.9529 (444.77)      54,619.9200 (619.05)      403.7169 (74.11)       54,645.0875 (636.61)      733.2700 (962.82)       16;0      18.3083 (0.00)         40           1
test_nanogpt_gpt2xl_grad[thunder]                        127,010.2949 (>1000.0)    129,932.8168 (>1000.0)    128,392.6164 (>1000.0)   1,271.2247 (233.37)     128,682.6360 (>1000.0)   2,229.3555 (>1000.0)       2;0       7.7886 (0.00)          5           1
test_llama2_causal_self_attention_7b_grad[thunder]       180,725.4220 (>1000.0)    185,217.8711 (>1000.0)    183,897.6699 (>1000.0)     912.1753 (167.46)     183,977.0535 (>1000.0)   1,265.6110 (>1000.0)      12;1       5.4378 (0.00)         40           1
test_llama2_mlp_7b_grad[thunder]                         238,051.7379 (>1000.0)    243,403.6671 (>1000.0)    240,681.5872 (>1000.0)   1,245.0415 (228.57)     240,967.0515 (>1000.0)   1,819.9860 (>1000.0)      11;0       4.1549 (0.00)         40           1
test_open_llama_7b_fwd[thunder]                        1,083,434.2770 (>1000.0)  1,085,535.2401 (>1000.0)  1,084,531.6276 (>1000.0)     809.0349 (148.52)   1,084,406.5200 (>1000.0)   1,167.9122 (>1000.0)       2;0       0.9221 (0.00)          5           1
test_llama_2_7b_grad[thunder]                          1,759,139.5450 (>1000.0)  1,798,621.3120 (>1000.0)  1,781,164.2416 (>1000.0)  11,620.1781 (>1000.0)  1,782,893.3335 (>1000.0)  20,682.9539 (>1000.0)       8;0       0.5614 (0.00)         20           1
test_llama_2_7b_hf_fwd[thunder]                        2,260,459.6941 (>1000.0)  2,275,246.8339 (>1000.0)  2,268,166.2148 (>1000.0)   5,800.4606 (>1000.0)  2,268,238.2530 (>1000.0)   8,989.1682 (>1000.0)       2;0       0.4409 (0.00)          5           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Benchmarked on the following GPU:

$ nvidia-smi -L
GPU 0: NVIDIA A100 80GB PCIe (UUID: GPU-af301991-abb4-a8db-ec48-be7abb9c59b3)

@wujingyue
Copy link
Collaborator

Are you going to fix the PR description? Otherwise LGTM!

@wujingyue wujingyue changed the title Enable cat for nvfuser >= 0.1.7 (PR1844) Enable cat for nvfuser >= 0.1.7 Mar 21, 2024
Fixes the failing test_cse_rematerialization
@jacobhinkle jacobhinkle enabled auto-merge (squash) March 21, 2024 17:45
@wujingyue
Copy link
Collaborator

I don't understand why "docs-make / make-docs" is still pending, but this should be ready to merge.

Copy link
Collaborator

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you @jacobhinkle @wujingyue

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants