How did flash-attn compute attention for cu_seqlens #850
Comments
We just launch parallel work (i.e. thread blocks) to process each attn head of each sequence, and each thread block will figure out the start and end idx of its sequence from cu_seqlens.
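For anyone following along, here is a minimal Python/PyTorch model of the scheduling described above (not the actual CUDA kernel): one unit of work per (sequence, head), where each unit reads its start and end token indices from cu_seqlens and runs attention only over that slice of the packed tensors. Shapes and names are illustrative.

```python
import torch

def naive_varlen_attention(q, k, v, cu_seqlens):
    """Reference model of varlen attention over packed (unpadded) tensors.

    q, k, v: (total_tokens, num_heads, head_dim), all sequences packed
             back-to-back with no padding.
    cu_seqlens: (batch + 1,) int32 cumulative sequence lengths,
                e.g. [0, 5, 12, 20] for lengths 5, 7, 8.
    """
    total, num_heads, head_dim = q.shape
    out = torch.empty_like(q)
    scale = head_dim ** -0.5
    batch = cu_seqlens.numel() - 1

    # Conceptually, the kernel launches one thread block per (sequence, head).
    for b in range(batch):
        start, end = cu_seqlens[b].item(), cu_seqlens[b + 1].item()
        for h in range(num_heads):
            qs = q[start:end, h]             # (seqlen_b, head_dim)
            ks = k[start:end, h]
            vs = v[start:end, h]
            scores = (qs @ ks.T) * scale     # attention stays inside this sequence
            out[start:end, h] = torch.softmax(scores, dim=-1) @ vs
    return out
```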
Got it. Thanks for the explanation!
We observe that the flash API provides fwd/bwd and varlen_fwd/varlen_bwd APIs to handle inputs without/with cu_seqlens. In the kernel of flash attention, a structure named BlockInfo is defined to store the offsets of q, k, and v. These offsets are calculated based on cu_seqlens. Hence, each sequence in the packed qkv can have a variable length, since the GEMM computation is broken down row by row, thereby preventing the waste of computational resources on meaningless padding tokens.
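A rough sketch of the offset bookkeeping the comment refers to, written in Python for illustration only (the real BlockInfo is a C++ struct inside the CUDA kernels, and its member names and stride math differ): the row offset of a sequence into the packed q/k/v buffers comes directly from cu_seqlens.

```python
def block_info(cu_seqlens_q, cu_seqlens_k, batch_idx, q_row_stride, k_row_stride):
    """Hypothetical: where does sequence `batch_idx` start inside the packed buffers?"""
    sum_s_q = cu_seqlens_q[batch_idx]                      # tokens before this sequence in q
    sum_s_k = cu_seqlens_k[batch_idx]                      # tokens before this sequence in k/v
    actual_seqlen_q = cu_seqlens_q[batch_idx + 1] - sum_s_q
    actual_seqlen_k = cu_seqlens_k[batch_idx + 1] - sum_s_k
    q_offset = sum_s_q * q_row_stride                      # element offset of the first q row
    k_offset = sum_s_k * k_row_stride                      # element offset of the first k row
    return q_offset, k_offset, actual_seqlen_q, actual_seqlen_k
```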
Hi,
We know that cu_seqlens is for compute efficiency when training over multiple variable-length samples, and that the attention mask can be calculated from cu_seqlens. We could cut the original packed (cumulative) sequence into a batched sequence and pad the empty positions with zeros, but this approach wastes training efficiency, since computing resources are consumed on meaningless padding tokens.
I am not very familiar with the implementation details of flash-attn, so I am curious where I can find the implementation or mechanism by which flash-attn computes attention directly over the cumulative sequence and gets separate results for each sample.
Thanks!
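For reference, a minimal end-to-end sketch of the varlen path discussed in the replies above, assuming the flash-attn 2.x Python API (flash_attn_varlen_func; the exact signature may vary across versions): pack the sequences back-to-back, build cu_seqlens from the cumulative lengths, and call the varlen function directly, so no padding tokens are ever materialized.

```python
import torch
from flash_attn import flash_attn_varlen_func  # assumes flash-attn 2.x

# Three variable-length sequences packed back-to-back, no padding tokens.
seqlens = [5, 7, 8]
num_heads, head_dim = 8, 64
total = sum(seqlens)

q = torch.randn(total, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# cu_seqlens marks the sequence boundaries inside the packed tensors: [0, 5, 12, 20].
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(torch.tensor(seqlens, device="cuda"), dim=0)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,
)
print(out.shape)  # (total_tokens, num_heads, head_dim), still packed, no padding
```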