Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Lowering for FlexAttention Backwards (pytorch#125515)
# Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: pytorch#123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | <details> <summary>Full Data</summary> | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 | </details> Pull Request resolved: pytorch#125515 Approved by: https://github.com/Chillee
- Loading branch information