Cache attention keys + values to speed up inference #216

Merged: 7 commits merged into main from petew-cache-attn on Jun 20, 2023

Conversation

@epwalsh (Member) commented on Jun 19, 2023:

Adds support for attention key/value caching and enables this in our .generate() method.
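For context, a minimal sketch of what key/value caching looks like for scaled-dot-product attention in PyTorch. This is illustrative only and does not mirror the attention module in this PR; the function name `attend_with_cache` and the tensor layout are assumptions for the example.

```python
# Minimal sketch of key/value caching for scaled-dot-product attention.
# NOTE: illustrative only; `attend_with_cache` and its shapes are assumptions,
# not the implementation added in this PR.
import torch
import torch.nn.functional as F


def attend_with_cache(q, k, v, past_key_value=None):
    """q, k, v: (batch, n_heads, q_len, head_dim).

    On the prefill pass, past_key_value is None and the whole prompt is
    processed causally. On each subsequent decoding step, q_len == 1 and the
    cached keys/values are prepended, so only the new token's attention is
    computed instead of re-running the full sequence.
    """
    if past_key_value is not None:
        past_k, past_v = past_key_value
        k = torch.cat([past_k, k], dim=-2)  # (batch, n_heads, past_len + q_len, head_dim)
        v = torch.cat([past_v, v], dim=-2)
    # A causal mask is only needed when there is more than one query position
    # (the prefill pass); a single new token may attend to the entire prefix.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=(past_key_value is None))
    return out, (k, v)
```

A generation loop would then call this once per new token, feeding the returned (k, v) back in as past_key_value on the next step.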

Comment on lines +691 to +694
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
# scores correctly.
or past_key_values is not None
@epwalsh (Member, Author):

I'm not sure why F.scaled_dot_product_attention() doesn't give the right scores in this situation, but since this works now I stopped looking into it.
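For readers hitting the same issue: one plausible explanation (a guess, not something this PR confirms) is that the implicit mask built by is_causal=True assumes query and key sequences of equal length, which no longer holds once cached keys/values make the key sequence longer than the query. Always materializing the attention bias, as the condition above ensures, sidesteps that. A hedged sketch of such a bias, with a hypothetical helper name that is not part of OLMo's API:

```python
# Hypothetical helper (not OLMo's API) that builds an additive causal bias
# for the cached case, where the queries correspond to the *last* q_len
# positions of a key sequence of length k_len.
import torch


def causal_bias_with_cache(q_len: int, k_len: int, dtype=torch.float32):
    # Query position i sits at absolute position (k_len - q_len + i) and may
    # attend to every key at or before that position.
    q_pos = torch.arange(k_len - q_len, k_len).unsqueeze(-1)  # (q_len, 1)
    k_pos = torch.arange(k_len).unsqueeze(0)                  # (1, k_len)
    allowed = k_pos <= q_pos                                  # (q_len, k_len)
    bias = torch.zeros(q_len, k_len, dtype=dtype)
    bias.masked_fill_(~allowed, torch.finfo(dtype).min)
    # Broadcastable over (batch, n_heads); pass as attn_mask= to
    # F.scaled_dot_product_attention instead of relying on is_causal.
    return bias
```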

@OyvindTafjord (Contributor) commented:

I haven't studied the code in detail, but at a glance it looks reasonable. I did try to rerun an experiment from earlier today with a 1B model, and it got identical predictions on all 2000 instances across two tasks, but much faster:

01H3A8VA2615CAMJCB3VC0HG2H  NaturalQs: 322 sec, DROP: 6133 sec
01H3B5PNVWEB22RPRYY3QGRHTB  NaturalQs: 143 sec, DROP: 480 sec

It's more than 12x faster for the longer (18 words avg, but with some longer outliers) DROP answers, and over 2x faster for the shorter (10 words avg) NaturalQs answers.

@OyvindTafjord (Contributor) commented:

Here are similar numbers on two summarization tasks:

01H30Y87MDV9T2KNNZBJ7S6SSW  SciTLDR: 2700 sec, XSum: 3300 sec
01H3BAC2WB01VYK04KWXPH5R9D  SciTLDR:  470 sec, XSum:  500 sec

Again with identical metrics and a notable (~6x) speedup.

@OyvindTafjord (Contributor) left a comment:

LGTM! I don't see any red flags in a cursory look over the code, and my end-to-end testing seems to indicate that the code works as intended.

@epwalsh (Member, Author) commented on Jun 20, 2023:

Thank you @OyvindTafjord! Glad to hear there's a big speedup.

@epwalsh merged commit 7c866c9 into main on Jun 20, 2023 (10 checks passed).
@epwalsh deleted the petew-cache-attn branch on Jun 20, 2023 at 16:37.