Cross-Attention with key_padding_mask #127

Open
jvend opened this issue Feb 22, 2023 · 1 comment

jvend commented Feb 22, 2023

Maybe I missed it, but I didn’t see any code using flash-enabled cross attention with key_padding_mask analogous to FlashAttention in flash_attn/flash_attention.py. Is there any reason this is the case? I have a working implementation (with the same structure as FlashAttention) and would be happy to submit a pull request if there's interest. Thanks for the great work!
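
For context, a minimal sketch of what such a wrapper could look like, assuming the v1-era flash_attn interface (`flash_attn_unpadded_kvpacked_func` plus `unpad_input` from `flash_attn.bert_padding`); exact names and signatures may differ across versions, and this is not the implementation referenced above:

```python
import torch
import torch.nn as nn
from einops import rearrange

# Assumed v1-era flash-attn imports; module paths/names may differ by version.
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
from flash_attn.bert_padding import unpad_input


class FlashCrossAttentionWithMask(nn.Module):
    """Hypothetical cross-attention wrapper with key_padding_mask support,
    structured like FlashAttention in flash_attn/flash_attention.py."""

    def __init__(self, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, kv, key_padding_mask=None, causal=False):
        # q: (batch, seqlen_q, nheads, headdim), kv: (batch, seqlen_k, 2, nheads, headdim),
        # both fp16/bf16 on GPU. key_padding_mask: (batch, seqlen_k) bool, True = keep.
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = kv.shape[1]

        # Queries are assumed unpadded here, so just flatten them and build trivial cu_seqlens.
        q_unpad = rearrange(q, 'b s h d -> (b s) h d')
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                    dtype=torch.int32, device=q.device)

        if key_padding_mask is None:
            kv_unpad = rearrange(kv, 'b s two h d -> (b s) two h d')
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                        dtype=torch.int32, device=kv.device)
            max_seqlen_k = seqlen_k
        else:
            # Drop the padded key/value tokens before calling the kernel.
            kv_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)

        out_unpad = flash_attn_unpadded_kvpacked_func(
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, max_seqlen_k,
            self.dropout_p if self.training else 0.0,
            softmax_scale=self.softmax_scale, causal=causal,
        )
        return rearrange(out_unpad, '(b s) h d -> b s h d', b=batch_size)
```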

tridao (Contributor) commented Feb 22, 2023

I think we have some of that here. If you have something easier to use, I would love to see it.

In general for the best performance, one should remove all the padding tokens at the very beginning, pass through all the layers, and then add back the padding tokens so as to avoid wasting computation. We do that in our BERT implementation.
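
As a rough sketch of that unpad-once pattern, using the `unpad_input` / `pad_input` helpers from `flash_attn.bert_padding` (the per-layer call signature below is hypothetical, standing in for blocks that consume `cu_seqlens` / `max_seqlen` the way the BERT implementation does):

```python
from flash_attn.bert_padding import unpad_input, pad_input


def run_encoder(hidden_states, attention_mask, layers):
    # hidden_states: (batch, seqlen, hidden); attention_mask: (batch, seqlen) bool, True = keep.
    batch, seqlen = hidden_states.shape[:2]
    # Remove the padding tokens once, up front: (total_tokens, hidden) plus sequence metadata.
    x, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
    # Run every layer on the packed (unpadded) tokens; no compute is spent on padding.
    for layer in layers:
        x = layer(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
    # Add the padding tokens back at the very end, recovering (batch, seqlen, hidden).
    return pad_input(x, indices, batch, seqlen)
```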
