[WIP] Dynamic length in static cache #30862
base: main
Conversation
```diff
@@ -1218,6 +1231,7 @@ def prepare_inputs_for_generation(
         "past_key_values": past_key_values,
         "use_cache": use_cache,
         "attention_mask": attention_mask,
+        "_length": int(cache_position[-1]) + 1,
```
This is redundant with `cache_position`; however, it is the only way I can figure out to make the dynamic length computation work with `torch.compile`.
TBH I also think it makes sense to slice away the useless ops; I wondered about this question myself :) The speedup of 15% is very nice, and I can confirm the speedup on my setup as well! (RTX3090) 🔥

Regarding the API (`_length`): I understand why it is done. Without an integer in the signature of `GemmaSdpaAttention.forward`, slicing tensors like `key_states` will always either fail, due to it being a data-dependent operation (= forbidden), or produce a variable-length tensor. With an integer, each value of the integer has its own compiled function with data-independent tensor slicing.

Still, if we are to go forward, we should find a better solution for the API. `StaticCache` already introduced the `cache_position` input, and this would further complicate the API. I see three possible paths:

1. `cache_position` becomes a list of integers instead of a tensor, and we use `cache_position[-1] + 1` to slice the tensors;
2. we pass the full `cache_position` array (a `torch.arange` up to the sequence length). The different shape of `cache_position` in each `GemmaSdpaAttention.forward` will trigger recompilation, solving the dynamic shape problem;
3. instead of `cache_position`, we use the sequence length (= `_length`, an `int`) to control generation with the static cache.

Note that in all 3 cases, the `StaticCache` needs a tensor like the current `cache_position`. However, it should be trivial to build one from any of the solutions above. From a usage perspective, option 3 is probably the easiest to understand. @ArthurZucker @ydshieh WDYT?
Exactly! For option 2, I am a bit worried. For example:

```python
if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
```

Currently, I personally would prefer option 3 for its simplicity, as long as we can re-build the tensor.
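For context, the `update` call above relies on `cache_position` to scatter the new states into the preallocated buffer. A toy sketch of that mechanism (not the real `StaticCache`, and with made-up shapes):

```python
import torch

# Toy sketch, not the real StaticCache: a preallocated key buffer updated
# in place at the positions given by a cache_position tensor.
class ToyStaticCache:
    def __init__(self, heads: int = 2, max_len: int = 16, head_dim: int = 4):
        self.key_cache = torch.zeros(1, heads, max_len, head_dim)

    def update(self, key_states: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
        # write the new key states at their absolute positions (prefill or decoding)
        self.key_cache.index_copy_(2, cache_position, key_states)
        return self.key_cache

cache = ToyStaticCache()
out = cache.update(torch.ones(1, 2, 3, 4), torch.arange(3))  # 3 prompt tokens
print(out[0, 0, :4, 0])  # tensor([1., 1., 1., 0.])
```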
Hey, before having a look: when you mention speedups, I don't think it makes sense to compute anything that does not use the full number of layers.
@ArthurZucker I am running on an A10, so even with gemma-2b (18 layers), I can only compile with a 768 sequence length. I benchmarked from 256 to 8192 (as long as it could compile within A10 GPU memory). The speedup gain and the reason behind it are kind of easy to see. However, is there any extra particular case you want me to run?
That is what was not clear to me; I wanted to know the number of generated tokens, not the prefill 😉
Yes. We probably need to come up with a good new approach, as @gante suggested.
Waiting for the final bench!
```python
if _length > 0:
    key_states = key_states[:, :, :_length, :]
    value_states = value_states[:, :, :_length, :]
    causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
```
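The slicing above should be a pure speed optimization. A quick check, with made-up shapes, that attending over the sliced cache matches attending over the full padded cache with the unused positions masked out:

```python
import torch
import torch.nn.functional as F

# Made-up shapes: one query token against a static cache of max length 16,
# of which only the first `_length` positions hold valid states.
_length, max_len = 5, 16
q = torch.randn(1, 2, 1, 8)              # (batch, heads, q_len, head_dim)
k = torch.randn(1, 2, max_len, 8)
v = torch.randn(1, 2, max_len, 8)

# full cache, with a mask hiding the unused positions
mask = torch.zeros(1, 1, 1, max_len)
mask[..., _length:] = float("-inf")
full = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# sliced cache: no padded positions left to mask
sliced = F.scaled_dot_product_attention(q, k[:, :, :_length], v[:, :, :_length])
print(torch.allclose(full, sliced, atol=1e-5))  # True
```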
This can only be an int; if it's a list, there is bound to be a device transfer.
Yeah, so far it is an int. Let me run the final bench first and come back to the API design.
@ydshieh I tried to use your implementation in my PR. I am also trying to get the actual length in compiled models, but in my case the length is used to decide which RoPE scaling to do. Therefore, passing the length as a model kwarg fails with dynamic control flow in the fullgraph setting. So, what do you guys think about going back to that approach? cc @gante @ArthurZucker
Could you point me to the lines where the length is used and where the compile issue appears? I could take a look. Oh, you use the length as a conditional?
Yes, in the Phi3 RoPE it's used as a conditional, and I've been trying to compile it.
Do you still have that commit (where you integrate your PR with mine and it leads to the compile failure)? If so, could you share it please 🙏
Sorry, I reverted your changes, but I just pushed the one which works for me, with `seen_tokens`. I get the length there and then use it in the Phi3 RoPE.
Update: we discussed with @ydshieh using the length in cond control flow. It works indeed, but only in torch 2.3.0 or 2.4.0; in 2.2.0 it would fail. So this feature will also benefit Phi3 compilation, when merged :)
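A hedged sketch of the kind of length-based conditional discussed for Phi3; the function name and threshold here are made up. Because the length is a plain int, the Python `if` is resolved during compilation, with one specialized graph per value:

```python
# Hypothetical sketch of a length-dependent RoPE scaling choice, loosely
# modeled on the Phi3 discussion above; names and threshold are made up.
def rope_scale_factor(length: int, original_max_len: int = 4096) -> float:
    # with an int length, this branch is resolved at compile time
    if length > original_max_len:
        return length / original_max_len
    return 1.0

print(rope_scale_factor(8192))  # 2.0
print(rope_scale_factor(1024))  # 1.0
```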
Hi @gante, when I run it, this PR doesn't introduce any extra recompilation.
(force-pushed from 862cde4 to b447901)
But we will need a tensor in `StaticCache`. Given a length (say `_length`), we can probably use `torch.arange` to re-build it. @gante do you have any comment regarding this, or anything you think I could give a try?
What does this PR do?

The current version is a minimal change that works, maybe not in the best way.

The current static cache is nice (when running with `torch.compile`). However, in each generation step, the new position (to be generated) computes the attentions against all positions in the cache, which is not optimal. In fact, we only need to compute the attentions against the positions prior to the current position.

This PR implements dynamic length computation with the static cache, which works with `torch.compile`. The following table demonstrates the speedup gain (with `torch.compile`) of this implementation over the current `main` branch. The correctness is verified.

The data below is based on this script, with some modifications to run it with different configurations, running on an A100 with `torch==2.3+cu121`.

Benchmark

I will re-run (part of) the benchmark, as the following numbers are on top of an older commit of `main`.

benchmark data on the hub

Static cache compiled: full length vs. optimal length (this PR)

gemma-2b (18 layers)