
Masking + biasing #17

Open
gahdritz opened this issue Jul 3, 2022 · 25 comments
@gahdritz
Contributor

gahdritz commented Jul 3, 2022

How difficult would it be to modify the kernels to support arbitrary masks/additive biases for the attention logits, especially with support for tensor broadcasting? Is there any fundamental reason why that wouldn't work? I noticed that the FlashAttention class has an attn_mask parameter, but it doesn't actually let you specify one.

On an unrelated note, would it be worth adding FP32 support for inference?

@gahdritz gahdritz changed the title from "Masking" to "Masking + biasing" on Jul 3, 2022
@tridao
Contributor

tridao commented Jul 3, 2022

We plan to support additive biases (e.g. ALiBi). The bias should ideally take linear memory and not quadratic memory.
Can you explain what shape the arbitrary masks would be? What does tensor broadcasting mean in this case?

Re FP32: we use tensor cores for the matrix multiplies, which support fp16 (and bf16 in the near future). The xformers team has an implementation of memory-efficient attention in fp32 (here and here).

@gahdritz
Contributor Author

gahdritz commented Jul 3, 2022

Thanks for the speedy response!

I'm thinking of applying FlashAttention to our implementation of AlphaFold 2, which has a number of different attention modules with different biases for the pre-softmax quadratic attention matrix S = Q @ K^T. To save memory, the biases are deliberately designed to be smaller than the full S tensor of shape e.g. [B, H, N, N]: they might have shapes like [1, 1, N, N] or [1, 1, 1, N], and are only broadcast to S's size at the moment of addition.
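(For concreteness, a minimal reference-style sketch of how such broadcast biases enter the logits; the sizes and the names key_mask_bias / pair_bias are illustrative placeholders, not the actual AlphaFold 2 implementation code.)

import torch

B, H, N, D = 4, 8, 256, 32                 # illustrative sizes only
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)

# Biases are stored in small, broadcastable shapes rather than [B, H, N, N].
key_mask_bias = torch.zeros(1, 1, 1, N)    # e.g. -inf at masked key positions
pair_bias = torch.randn(1, 1, N, N)        # e.g. a trainable pairwise bias

s = q @ k.transpose(-2, -1) / D ** 0.5     # [B, H, N, N]
s = s + key_mask_bias + pair_bias          # broadcast to S's shape only here
out = torch.softmax(s, dim=-1) @ v         # [B, H, N, D]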

BTW, is there a reason the existing attn_mask parameter is disabled? Would it be straightforward to get it enabled?

@tridao
Contributor

tridao commented Jul 3, 2022

Can you point me to the code / paper on how these biases are applied? Are they just added to S (potentially with broadcasting)? Are they fixed or trainable (i.e. do we need to compute their gradient in the backward pass)?

attn_mask is not implemented (yet); it's just there in the Python code to stay compatible with our internal testing code.
We do support key_padding_mask.

@gahdritz
Contributor Author

gahdritz commented Jul 3, 2022

See e.g. our implementation of what DeepMind calls "triangular attention." Pseudocode on page 19 here. It uses both a fixed mask and a trainable bias (mask_bias and triangle_bias in the code, respectively). Both are simply added to S with broadcasting.

@tridao
Contributor

tridao commented Jul 3, 2022

Thanks for the pointers!
Is the mask_bias used to mask out the keys (it should have shape [1, 1, 1, N]), or to mask out arbitrary entries in the attention matrix (shape [1, 1, N, N])?
How large is the sequence length N typically?

@gahdritz
Contributor Author

gahdritz commented Jul 3, 2022

In this particular case, the mask_bias masks out the keys (but is unique for each element in the (sub-)batch, so it has shape [B, 1, 1, N]). The triangle_bias, however, is quadratic. N usually ranges from 0 to 2000, but can go as high as ~10000.

@gahdritz
Contributor Author

gahdritz commented Jul 4, 2022

Sorry, I forgot to ask yesterday: do you have an approximate ETA for the attn_mask feature? Even without the additive bias feature, FlashAttention would give us a huge speedup in certain modules; we just need that mask.

Thanks!

@tridao
Contributor

tridao commented Jul 4, 2022

Do you mean a key padding mask (of shape [B, 1, 1, N]), or an arbitrary attention mask of shape [1, 1, N, N]?
We already support key padding mask ([B, 1, 1, N]), query padding mask ([B, 1, N, 1]) and causal mask (for autoregressive modeling).

@gahdritz
Contributor Author

gahdritz commented Jul 4, 2022

I tried playing around with the key_padding_mask feature, and as I understand it, in addition to masking out keys, it also masks out the corresponding queries and values (it stacks them all and then uses the unpad_input function to unpad all three at once). Is that incorrect? I might be looking in the wrong place, since I didn't see a specific query padding mask option. I'm looking for the ability to use the [B, 1, 1, N] mask to mask (i.e. set to negative infinity before the softmax) columns of S.

@tridao
Contributor

tridao commented Jul 4, 2022

I've just finished implementing cross-attention where we can support separate query padding and key padding (here and here).
That is, you can call unpad_input on the query and then unpad_input on the keys and values separately.

I'll push an example usage tomorrow.
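(In the meantime, here is a rough sketch of what separate query/key unpadding can look like, assuming the unpad_input / pad_input helpers from flash_attn.bert_padding and the varlen interface referenced later in this thread; exact function names and return values may differ between releases.)

import torch
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn import flash_attn_varlen_func  # name may differ in older releases

B, N_q, N_k, H, D = 2, 128, 256, 8, 64
q = torch.randn(B, N_q, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, N_k, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, N_k, H, D, device="cuda", dtype=torch.float16)
q_mask = torch.ones(B, N_q, dtype=torch.bool, device="cuda")  # True = keep this query
k_mask = torch.ones(B, N_k, dtype=torch.bool, device="cuda")  # True = keep this key/value

# Unpad queries and keys/values with their own padding masks.
q_unpad, q_idx, cu_seqlens_q, max_seqlen_q = unpad_input(q, q_mask)
k_unpad, k_idx, cu_seqlens_k, max_seqlen_k = unpad_input(k, k_mask)
v_unpad, _, _, _ = unpad_input(v, k_mask)

out_unpad = flash_attn_varlen_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
)
out = pad_input(out_unpad, q_idx, B, N_q)  # scatter back to [B, N_q, H, D]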

@gahdritz
Contributor Author

gahdritz commented Jul 4, 2022

Oh awesome. Thanks!

@void-main

We want customizable masking & biasing as well! Adding these two features would make FlashAttention suitable for a lot more models.

@OlivierDehaene

Hello @tridao,
Do you still plan on adding support for additive biases? I am currently working on speeding up inference for BLOOM, and being able to use flash attention would be great!

@tridao
Contributor

tridao commented Sep 23, 2022

Yes, it'll be there eventually. I just haven't had as much bandwidth to work on it recently (conference deadline).

By inference, do you mean text generation? Would Q have seqlen 1 most of the time, due to the kv_cache? In that case I think there might not be much benefit to using FlashAttention (attention is not IO-bound).
We have ideas on speeding up text generation in this regime, but that will take a bit of time.

@guolinke

refer to #57

@zeliu98

zeliu98 commented Nov 5, 2022

Hi @tridao, congratulations on your great work!

I have the same issue when using flash attention in Swin Transformer, especially for shifted window attention.
It would be much appreciated if flash attention could support a custom attention mask and bias.

To be more specific, the shapes in the window attention are:
q, k, v: [batch_size, nheads, ntokens, dim]
attn = q @ k.transpose(-2, -1): [batch_size, nheads, ntokens, ntokens]
mask: [batch_size, 1, ntokens, ntokens]
bias: [1, nheads, ntokens, ntokens]

The final attention weights (before softmax) should be: attn + mask + bias
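(A small illustration of the point above: both terms broadcast to the full logit shape, so an interface that accepts additive terms broadcastable to the attention logits would cover this case. Sizes below are made up.)

import torch

bsz, nheads, ntokens = 64, 4, 49                  # illustrative shifted-window sizes
mask = torch.randn(bsz, 1, ntokens, ntokens)      # per-window mask (0 or -inf in practice)
bias = torch.randn(1, nheads, ntokens, ntokens)   # relative position bias (trainable)

# Both broadcast to [bsz, nheads, ntokens, ntokens], the shape of the attention logits.
combined = mask + bias
print(combined.shape)  # torch.Size([64, 4, 49, 49])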

@MayDomine

refer to #76

@calebthomas259
Contributor

Hello @tridao, flash-attention is amazing! Thank you so much for making it!

If possible, I'd also like to request a fully customisable attention bias (i.e. shape [B, H, N, N]). This would allow for implementing papers like Graphormer (e.g. see Figure 1 of the linked paper).

Thank you for all of your hard work!

@MayDomine

MayDomine commented Jan 4, 2023

Hello @tridao, flash-attention is amazing! Thank you so much for making it!

If possible, I'd also like to request a fully customisable attention bias (i.e. shape [B, H, N, N]). This would allow for implementing papers like Graphormer (e.g. see Figure 1 of the linked paper).

Thank you for all of your hard work!

Actually, you can simply use the Triton implementation to achieve this. But for the backward pass you have to modify the original code, and performance will decrease because you have to save the bias gradient in HBM during the backward pass.
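(For anyone wanting to try that route, a rough sketch of what the call might look like, assuming the Triton implementation in this repo exposes an additive bias argument as described above; check flash_attn/flash_attn_triton.py for the exact signature and supported bias shapes.)

import torch
from flash_attn.flash_attn_triton import flash_attn_func  # Triton version; API may change

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)

# Additive bias broadcastable to [batch, nheads, seqlen, seqlen]. As noted above,
# getting a gradient for the bias itself requires modifying the backward code.
bias = torch.randn(batch, nheads, seqlen, seqlen, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, bias)
out.sum().backward()  # dq, dk, dv are computed; the bias gradient is not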

@calebthomas259
Contributor

Hello @tridao, flash-attention is amazing! Thank you so much for making it!
If possible, I'd also like to request a fully customisable attention bias (i.e. shape [B, H, N, N]). This would allow for implementing papers like Graphormer (e.g. see Figure 1 of the linked paper).
Thank you for all of your hard work!

Actually, you can simply use the Triton implementation to achieve this. But for the backward pass you have to modify the original code, and performance will decrease because you have to save the bias gradient in HBM during the backward pass.

I did try this at some point, but I was getting errors (I'm not sure whether it was my code being wrong or just Triton bugs). I'll probably try again once Triton is more stable.

@UCC-team

I've just finished implementing cross-attention where we can support separate query padding and key padding (here and here). That is, you can call unpad_input on the query and then unpad_input on the keys and values separately.

I'll push an example usage tomorrow.

Where can we find the example?
Thank you for all of your hard work!

@tridao
Contributor

tridao commented Aug 17, 2023

You can look at our BERT implementation.

@UCC-team

You can look at our BERT implementation.

If we convert a sparse mask to cu_seqlens and max_seqlen using unpad_input, we get incorrect results. The results of attention_ref(q, k, v, q_mask, k_mask) and flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) do not match.
(The sparse mask is not a padding mask; the masked positions might appear in the middle of the sequence rather than at the end.)

_, _, cu_seqlens_q, max_seqlen_q = unpad_input(q, q_mask)  # drops masked tokens and packs the rest
_, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, k_mask)

How should we solve this problem? Thanks!

@tridao
Contributor

tridao commented Aug 29, 2023

If the mask is in the middle of the sequence, that's not supported right now.

@UCC-team

If the mask is in the middle of the sequence, that's not supported right now.

Can the blockmask in the flash_blocksparse_attn_func interface be used to express the kinds of masks mentioned above?
