
How did flash-attn compute attention for cu_seqlens #850

Closed
zigzagcai opened this issue Feb 23, 2024 · 3 comments

@zigzagcai commented Feb 23, 2024

Hi,

We know that cu_seqlens exists for compute efficiency when training over multiple variable-length samples, and that the attention mask can be derived from cu_seqlens. We could split the packed (cumulative) sequence back into a batch and pad the empty positions with zeros, but that approach hurts training efficiency, since compute is spent on meaningless padding tokens.
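For illustration only (hypothetical sequence lengths, not flash-attn code), a minimal sketch of the two layouts: a zero-padded batch versus a packed tensor indexed by cu_seqlens:

```python
import torch

nheads, headdim = 2, 8
seqlens = [5, 3, 7]                                      # hypothetical per-sample lengths
seqs = [torch.randn(s, nheads, headdim) for s in seqlens]

# Padded layout: (batch, max_seqlen, nheads, headdim) -- compute is wasted on pad tokens.
padded = torch.zeros(len(seqlens), max(seqlens), nheads, headdim)
for i, s in enumerate(seqs):
    padded[i, : s.shape[0]] = s

# Packed ("varlen") layout: (total_tokens, nheads, headdim) plus cumulative offsets.
packed = torch.cat(seqs, dim=0)                                       # shape (15, nheads, headdim)
cu_seqlens = torch.tensor([0] + seqlens).cumsum(0).to(torch.int32)    # tensor([0, 5, 8, 15])

# Sample i lives at packed[cu_seqlens[i] : cu_seqlens[i + 1]].
assert torch.equal(packed[int(cu_seqlens[1]) : int(cu_seqlens[2])], seqs[1])
```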

I am not very familiar with the implementation details of flash-attn, so I am curious: where can I find the implementation, or the mechanism, by which flash-attn computes attention directly over the packed (cumulative) sequence and produces separate results for each sample?

Thanks!

@tridao (Contributor) commented Feb 23, 2024

We just launch parallel work (i.e. thread blocks) to process each attn head of each sequence, and each thread block will figure out the start and end idx of each sequence from cu_seqlens.
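(For illustration, a rough Python analogy of that partitioning. The real implementation is a CUDA kernel; this reference loop only shows how each per-sequence, per-head unit of work derives its start and end indices from cu_seqlens.)

```python
import torch

def varlen_attention_reference(q, k, v, cu_seqlens):
    """q, k, v: (total_tokens, nheads, headdim); cu_seqlens: (batch + 1,) int tensor."""
    out = torch.empty_like(q)
    softmax_scale = q.shape[-1] ** -0.5
    for b in range(cu_seqlens.numel() - 1):        # one unit of work per sequence...
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        for h in range(q.shape[1]):                # ...and per attention head
            scores = (q[start:end, h] @ k[start:end, h].transpose(0, 1)) * softmax_scale
            out[start:end, h] = torch.softmax(scores, dim=-1) @ v[start:end, h]
    return out
```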

@zigzagcai (Author)

> We just launch parallel work (i.e. thread blocks) to process each attn head of each sequence, and each thread block will figure out the start and end idx of each sequence from cu_seqlens.

Got it. Thanks for the explanation!

@zigzagcai (Author) commented Mar 6, 2024

> We just launch parallel work (i.e. thread blocks) to process each attn head of each sequence, and each thread block will figure out the start and end idx of each sequence from cu_seqlens.

We observe that the flash-attn API provides fwd/bwd and varlen_fwd/varlen_bwd entry points to handle inputs without/with cu_seqlens. Both input patterns are passed into run_mha_fwd and run_mha_bwd, and are ultimately processed by the templated flash_fwd_kernel and flash_bwd_kernel.
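(For reference, a hedged usage sketch of the Python-level wrappers around these two paths; the argument names below follow my reading of flash_attn.flash_attn_interface, so check the installed version for the exact signatures.)

```python
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

nheads, headdim = 4, 64
device, dtype = "cuda", torch.float16

# Fixed-length path: (batch, seqlen, nheads, headdim) -> fwd/bwd.
q = torch.randn(2, 128, nheads, headdim, device=device, dtype=dtype)
out = flash_attn_func(q, q, q, causal=True)

# Varlen path: packed (total_tokens, nheads, headdim) plus cu_seqlens -> varlen_fwd/varlen_bwd.
seqlens = [5, 3, 7]
qp = torch.randn(sum(seqlens), nheads, headdim, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0] + seqlens, device=device).cumsum(0).to(torch.int32)
out_varlen = flash_attn_varlen_func(
    qp, qp, qp,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,
)
```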

In the flash attention kernel, a structure named BlockInfo is defined to store the offsets of q, k, and v. These offsets are calculated from cu_seqlens_q and cu_seqlens_k, allowing attention to be computed one row block at a time (compute_attn_1rowblock). Through BlockInfo, each thread block knows which rows of q, k, and v it should compute attention over.

Hence, each sequence in the packed q, k, v can have a variable length, since the GEMM is broken down into row-block computations, which avoids wasting compute on meaningless padding tokens.
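(A small sketch of that offset bookkeeping in Python; the field names roughly mirror my reading of the BlockInfo struct and are approximations, not the actual C++ definitions.)

```python
def block_info(cu_seqlens_q, cu_seqlens_k, bidb):
    """Per-sequence offsets a thread block handling batch index bidb would use."""
    sum_s_q = int(cu_seqlens_q[bidb])                        # token offset of this sequence in packed q
    sum_s_k = int(cu_seqlens_k[bidb])                        # token offset in packed k/v
    actual_seqlen_q = int(cu_seqlens_q[bidb + 1]) - sum_s_q  # true (unpadded) query length
    actual_seqlen_k = int(cu_seqlens_k[bidb + 1]) - sum_s_k  # true (unpadded) key length
    return sum_s_q, sum_s_k, actual_seqlen_q, actual_seqlen_k

# e.g. with cu_seqlens_q = cu_seqlens_k = [0, 5, 8, 15], sequence 1 starts at token 5 and has length 3.
```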
