[WIP]: Initial DynamicFlexAttention wrapper class for dynamic sequence lengths #1960
base: main
Conversation
See the following gist: https://gist.github.com/zyklotomic/527cb96da86c2b5f5984bede3be9b227
Hey! Great PR! Do you know why …
I have some interesting findings to report back! I should have dug deeper initially. It turns out that getting dynamic shapes to work is something that has been actively worked on, and it is apparently available in the nightly version of PyTorch. Links of interest: https://github.com/pytorch/pytorch/blob/8d08b4901586f230353a558ee00c16ad57f95178/torch/_inductor/kernel/flex_attention.py#L705 (most recent commit as of writing), which points to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_decoding.py#L336
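Roughly, the dynamic-shapes route those links cover would look something like the sketch below. This is just an illustration, not something I've benchmarked: the shapes and dtypes are arbitrary, and whether `dynamic=True` actually avoids per-shape recompilation of the flex attention kernel depends on the nightly build.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Let torch.compile trace with dynamic sequence lengths instead of padding.
compiled_flex = torch.compile(flex_attention, dynamic=True)

def run(seq_len: int):
    q = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.float16)
    return compiled_flex(q, k, v)

out_a = run(384)  # first call compiles
out_b = run(512)  # ideally reuses the dynamic-shape kernel on nightly
```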
I did try my example notebook and set … Not a … As for your question on why … What do you think is the best course of action? Should we wait for the PyTorch folks to stabilize instead?
I think I only just understood what you mean. If I understand correctly, my wrapper class handles the padding for you based on the input size.
Demo notebook: https://colab.research.google.com/drive/1X7CpQgIqgRpV2aIUgS_p7u1TR4ITUfXF?usp=sharing — it might be a bit primitive to use a temporary print statement to confirm that the flex attention module was indeed being invoked, but I don't think there was any better way.
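One possible alternative to the temporary print would be a forward hook; just a sketch, where `model.attn` and `sample_input` are placeholders for whatever the demo notebook actually builds, not names from this PR:

```python
import torch

calls = []
# `model.attn` stands in for the flex attention module under test.
hook = model.attn.register_forward_hook(
    lambda mod, args, out: calls.append(type(mod).__name__)
)

with torch.no_grad():
    model(sample_input)  # whatever the demo notebook normally runs

assert calls, "the flex attention module was never invoked"
hook.remove()
```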
I had a stab at making FlexAttention work without excessive recompilation. I am not fully confident in this approach; it still feels pretty janky to me. Hence, I wanted confirmation on whether this is the right approach.
In essence, the kernel has to recompile every time the input sizes change. So why not compile a kernel once for a larger size, pad the inputs up to that size when necessary, and slice the padding back off the result before returning? See the code for more thorough comments.
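To make the pad-and-slice idea concrete, here is a rough, self-contained sketch. It is illustrative only — `PaddedFlexAttention`, `max_seq_len`, and the shapes are my own placeholder names, not the PR's actual `DynamicFlexAttention` API — and it assumes the PyTorch 2.5+ `flex_attention` / `create_block_mask` interface.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask


class PaddedFlexAttention(torch.nn.Module):
    # Illustrative only -- not the PR's DynamicFlexAttention API.
    def __init__(self, max_seq_len: int = 1024):
        super().__init__()
        self.max_seq_len = max_seq_len
        # Compile once; every call sees (B, H, max_seq_len, D) inputs,
        # so shorter sequences should not trigger a recompile.
        self.flex = torch.compile(flex_attention)
        # Keep the real length in a tensor so the mask_mod below does not
        # close over a Python int that the compiler would specialize on.
        self.register_buffer("seq_len", torch.tensor(0, dtype=torch.int32))

    def forward(self, q, k, v):
        # q, k, v: (B, H, S, D) with S <= max_seq_len
        B, H, S, D = q.shape
        assert S <= self.max_seq_len, "input longer than the compiled size"
        self.seq_len.fill_(S)
        pad = self.max_seq_len - S

        # Pad the sequence dimension up to the fixed compiled size.
        q_p = F.pad(q, (0, 0, 0, pad))
        k_p = F.pad(k, (0, 0, 0, pad))
        v_p = F.pad(v, (0, 0, 0, pad))

        # Mask out padded key positions so they do not affect the softmax.
        def mask_mod(b, h, q_idx, kv_idx):
            return kv_idx < self.seq_len

        block_mask = create_block_mask(
            mask_mod, B=None, H=None,
            Q_LEN=self.max_seq_len, KV_LEN=self.max_seq_len,
            device=q.device,
        )

        out = self.flex(q_p, k_p, v_p, block_mask=block_mask)
        # Slice the padding back off before returning.
        return out[:, :, :S, :]
```

The actual length is stored in a buffer rather than captured as a Python int because a plain int in the mask closure would likely get baked into the compiled kernel and defeat the purpose; I'm not certain this avoids recompilation in every case, which is part of what I'd like feedback on.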
I haven't had the chance to really test the performance yet. There are potential enhancements too, which I mention in the comments.
Will attach testing code for a demo in a bit.