Maybe I missed it, but I didn’t see any code using flash-enabled cross attention with `key_padding_mask`, analogous to `FlashAttention` in `flash_attn/flash_attention.py`. Is there a reason for this? I have a working implementation (with the same structure as `FlashAttention`) and would be happy to submit a pull request if there's interest. Thanks for the great work!
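For reference, here is the result that a flash-enabled cross attention with a key padding mask would need to reproduce, written as a plain, unoptimized computation. This is a hypothetical pure-Python sketch (the name `cross_attention` is illustrative and not part of flash_attn): single head, no batching, queries from one sequence and keys/values from another. It assumes `True` in the mask marks a real key and `False` marks padding. Note that this is the opposite of `torch.nn.MultiheadAttention`, where `True` in `key_padding_mask` means "ignore this key". A fused kernel would compute the same output without materializing the full score matrix.

```python
import math
from typing import List, Sequence


def cross_attention(
    q: Sequence[Sequence[float]],
    k: Sequence[Sequence[float]],
    v: Sequence[Sequence[float]],
    key_padding_mask: Sequence[bool],
) -> List[List[float]]:
    """Single-head softmax(q k^T / sqrt(d)) v that ignores padded keys.

    Assumed mask convention (hypothetical here): True = real key, False = padding.
    """
    if not any(key_padding_mask):
        raise ValueError("at least one key must be unmasked")
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        # Padded keys get a score of -inf, so exp() turns their weight into 0.
        scores = [
            scale * sum(a * b for a, b in zip(qi, kj)) if keep else -math.inf
            for kj, keep in zip(k, key_padding_mask)
        ]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vj[c] for w, vj in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out
```

Because padded keys receive zero weight, changing their values (or their keys) does not change the output.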
I think we have some of that here. If you have something easier to use, we'd love to see it.
In general, for the best performance, one should remove all the padding tokens at the very beginning, pass only the remaining tokens through all the layers, and then add the padding tokens back at the end, so that no computation is wasted on padding. We do that in our BERT implementation.
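The unpad → run all layers → re-pad pattern can be sketched as follows. This is a minimal pure-Python illustration, not the BERT implementation mentioned above. The helper names `unpad`, `pad`, `run_unpadded`, and `toy_layer` are hypothetical. The flat token list plus cumulative sequence lengths (`cu_seqlens`, where sequence `i` occupies `flat[cu_seqlens[i]:cu_seqlens[i + 1]]`) is the kind of variable-length layout FlashAttention's varlen interface works with. `toy_layer` is an assumed stand-in for attention: it only mixes tokens within the same sequence.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical layer signature: flat (unpadded) tokens + cu_seqlens -> new flat tokens.
Layer = Callable[[List[float], List[int]], List[float]]


def unpad(batch: Sequence[Sequence[float]], mask: Sequence[Sequence[bool]]):
    """Keep only positions where mask is True and record where each one came from."""
    flat: List[float] = []
    indices: List[Tuple[int, int]] = []
    cu_seqlens = [0]
    for b, (row, row_mask) in enumerate(zip(batch, mask)):
        for t, (tok, keep) in enumerate(zip(row, row_mask)):
            if keep:
                flat.append(tok)
                indices.append((b, t))
        cu_seqlens.append(len(flat))
    return flat, indices, cu_seqlens


def pad(flat, indices, batch_size, seq_len, pad_value=0.0):
    """Scatter the flat tokens back to their original (batch, position) slots."""
    out = [[pad_value] * seq_len for _ in range(batch_size)]
    for tok, (b, t) in zip(flat, indices):
        out[b][t] = tok
    return out


def run_unpadded(batch, mask, layers: Sequence[Layer]):
    """Unpad once, run every layer on real tokens only, then re-pad once."""
    flat, indices, cu_seqlens = unpad(batch, mask)
    for layer in layers:
        flat = layer(flat, cu_seqlens)
    return pad(flat, indices, len(batch), len(batch[0]))


def toy_layer(flat: List[float], cu_seqlens: List[int]) -> List[float]:
    """Stand-in for attention: each token adds the mean of its own sequence."""
    out = list(flat)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        if end > start:
            mean = sum(flat[start:end]) / (end - start)
            for i in range(start, end):
                out[i] = flat[i] + mean
    return out


batch = [[1.0, 2.0, 0.0], [4.0, 0.0, 0.0]]
mask = [[True, True, False], [True, False, False]]
print(run_unpadded(batch, mask, [toy_layer]))  # [[2.5, 3.5, 0.0], [8.0, 0.0, 0.0]]
```

Padding is removed and restored only once, so every layer in between operates on 3 real tokens instead of 6 padded positions.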