Cache attention keys + values to speed up inference #216
Conversation
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
# scores correctly.
or past_key_values is not None
I'm not sure why `F.scaled_dot_product_attention()` doesn't give the right scores in this situation, but since this works now I stopped looking into it.
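For context, here is a minimal sketch of what attention with a key/value cache might look like when passing an explicit additive attention bias to `F.scaled_dot_product_attention()` rather than relying on `is_causal`. The helper name, tensor shapes, and the `attn_bias` argument are illustrative assumptions, not the repo's actual code:

```python
import torch
import torch.nn.functional as F

def attend_with_cache(q, k, v, past_key_values=None, attn_bias=None):
    """Hypothetical sketch of one attention step with a key/value cache.

    q, k, v: (batch, n_heads, new_seq_len, head_dim) for the *new* tokens only.
    past_key_values: optional (past_k, past_v) tensors from earlier decoding steps.
    attn_bias: explicit additive bias covering the full (cached + new) key length;
        per the note above, passing an explicit bias appears to be needed for
        correct scores when a cache is present.
    """
    if past_key_values is not None:
        past_k, past_v = past_key_values
        # Extend keys and values along the sequence dimension with the cache.
        k = torch.cat([past_k, k], dim=-2)
        v = torch.cat([past_v, v], dim=-2)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
    # Return the updated cache so the next decoding step can reuse it.
    return out, (k, v)
```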
I haven't studied the code in detail, but at a glance it looks reasonable. I did try to rerun an experiment from earlier today with a 1B model, and it got identical predictions on all 2000 instances across two tasks, but much faster: it's more than 12x faster for the longer (18 words avg, but with some longer outliers) DROP answers, and over 2x faster for the shorter (10 words avg) NaturalQs answers.
The numbers on two summarization tasks are similar: again identical metrics, with a notable (~6x) speedup.
LGTM! I don't see any red flags from a cursory look over the code, and my end-to-end testing seems to indicate it works as intended.
Thank you @OyvindTafjord! Glad to hear there's a big speedup.
Adds support for attention key/value caching and enables this in our `.generate()` method.
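As a rough illustration of why caching speeds up generation, here is a minimal greedy-decoding loop in which only the newest token is fed through the model at each step while cached keys/values cover the prefix. The `model(next_input, past_key_values=...)` call and its return values are a hypothetical interface for this sketch, not the repo's actual `.generate()` implementation:

```python
import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens=50):
    """Hypothetical greedy decoding loop using a key/value cache.

    Assumes `model(tokens, past_key_values=...)` returns (logits, new_past_key_values);
    names and signature are illustrative only.
    """
    past_key_values = None
    tokens = input_ids
    next_input = input_ids  # first step: feed the whole prompt
    for _ in range(max_new_tokens):
        logits, past_key_values = model(next_input, past_key_values=past_key_values)
        # Pick the most likely next token from the last position's logits.
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
        # Subsequent steps: feed only the new token; the cache covers the rest.
        next_input = next_token
    return tokens
```

Without the cache, each step would re-run attention over the full sequence so far, which is where the 2x-12x speedups reported above come from.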