
[WIP] Dynamic length in static cache #30862

Draft · ydshieh wants to merge 2 commits into main from dynamic_length_in_static_cache

Conversation

@ydshieh (Collaborator) commented on May 16, 2024:

What does this PR do?

The current version is a minimal change that works; it may not be the best approach.

The current static cache is nice (when running with torch.compile). However, in each generation step, the newly generated position computes attention against all positions in the cache, which is not optimal: we only need to compute attention against the positions up to and including the current one.

This PR implements dynamic-length computation with the static cache, and it works with torch.compile. The table below demonstrates the speedup gain (with torch.compile) of this implementation over the current main branch.
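
As a rough, self-contained illustration (my own sketch, not code from this PR), the redundancy being removed looks like the following: the empty cache slots are masked out anyway, so slicing them away changes the amount of compute but not the result.

```python
import torch

B, H, L, D = 1, 8, 4096, 64   # batch, heads, max cache length, head dim
t = 17                        # current decode step: t + 1 positions are filled

key_cache = torch.zeros(B, H, L, D)
value_cache = torch.zeros(B, H, L, D)
key_cache[:, :, : t + 1] = torch.randn(B, H, t + 1, D)
value_cache[:, :, : t + 1] = torch.randn(B, H, t + 1, D)
query = torch.randn(B, H, 1, D)

# Current main: attend over the full static cache and mask out the empty slots.
mask = torch.full((1, 1, 1, L), float("-inf"))
mask[..., : t + 1] = 0.0
full = torch.nn.functional.scaled_dot_product_attention(
    query, key_cache, value_cache, attn_mask=mask
)

# This PR: slice keys/values (and the mask) down to the filled prefix first.
sliced = torch.nn.functional.scaled_dot_product_attention(
    query,
    key_cache[:, :, : t + 1],
    value_cache[:, :, : t + 1],
    attn_mask=mask[..., : t + 1],
)

print(torch.allclose(full, sliced, atol=1e-5))  # same output, far fewer FLOPs
```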

Correctness is verified by:

```bash
RUN_SLOW=1 TF_FORCE_GPU_ALLOW_GROWTH=true python3 -m pytest -v tests/models/gemma/test_modeling_gemma.py -k "test_compile_static_cache"
```

The data below is based on this script:
```python
import os
import torch
import datetime

from transformers import AutoTokenizer, AutoModelForCausalLM

token = "ADD_YOUR_OWN_TOKEN"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

batch_size = 1
n_iter = 5

ckpt = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(ckpt, token=token)
model = AutoModelForCausalLM.from_pretrained(ckpt, token=token, torch_dtype=torch.float16).to("cuda")

model.generation_config.max_new_tokens = 1024

# Use the static cache and compile the forward pass.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_text = "Why dogs are cute."
input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to("cuda")

# Time each generate call; the first iterations include compilation time.
for i in range(n_iter):
    s = datetime.datetime.now()
    outputs = model.generate(**input_ids, do_sample=False)
    t = datetime.datetime.now()
    e = (t - s).total_seconds()
    print(e)
```

I ran it with some modifications for the different configurations, on an A100 with torch==2.3+cu121.

Benchmark

I will re-run (part of) the benchmark, as the following numbers were obtained on top of an older commit of main.

benchmark data on the hub

Static cache compiled: full length vs. optimal length (this PR)

gemma-2b (18 layers)

| seq. length | speedup |
|------------:|--------:|
| 1024 | 1.03x |
| 2048 | 1.11x |
| 4096 | 1.24x |
| 8192 | 1.38x |


```diff
@@ -1218,6 +1231,7 @@ def prepare_inputs_for_generation(
             "past_key_values": past_key_values,
             "use_cache": use_cache,
             "attention_mask": attention_mask,
+            "_length": int(cache_position[-1]) + 1,
```
@ydshieh (Collaborator, Author):

This is redundant with cache_position; however, it is the only way I could figure out to make the dynamic-length computation work with torch.compile.

@gante (Member) left a comment:

TBH I also think it makes sense to slice away the useless ops; I have wondered about this myself :) The 15% speedup is very nice, and I can confirm the speedup on my setup as well! (RTX 3090) 🔥

Regarding the API (_length): I understand why it is done. Without an integer in the signature of GemmaSdpaAttention.forward, slicing tensors like key_states will always fail, either because it is a data-dependent operation (=forbidden) or because it produces a variable-length tensor. With an integer, each value of the integer gets its own compiled function with data-independent tensor slicing.

Still, if we are to go forward, we should find a better solution for the API. StaticCache already introduced the cache_position input; this would further complicate the API. I see three possible paths:

  1. cache_position becomes a list of integers instead of a tensor, and we use cache_position[-1] + 1 to slice the tensors;
  2. we pass the full cache_position array (a torch.arange up to the sequence length); the different shape of cache_position in each GemmaSdpaAttention.forward will trigger recompilation, solving the dynamic-shape problem;
  3. instead of cache_position, we use the sequence length (=_length, an int) to control generation with the static cache.

Note that in all 3 cases, the StaticCache needs a tensor like the current cache_position. However, it should be trivial to build it from any of the solutions above. From a usage perspective, option 3 is probably the easiest to understand. @ArthurZucker @ydshieh WDYT?
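
To make the integer argument concrete, here is a minimal sketch (my own illustration, not code from this PR or the library) of why an int bound keeps the slice compile-friendly under fullgraph=True:

```python
import torch

@torch.compile(fullgraph=True)
def attend_prefix(key_cache: torch.Tensor, query: torch.Tensor, length: int):
    # `length` is a Python scalar, so this slice is data-independent:
    # nothing about the slice bound has to be read out of a tensor's values.
    keys = key_cache[:, :, :length, :]
    return torch.matmul(query, keys.transpose(-1, -2))

key_cache = torch.randn(1, 8, 1024, 64)   # (batch, heads, max_cache_len, head_dim)
query = torch.randn(1, 8, 1, 64)

# Each new `length` value may trigger a recompilation (or, on recent torch
# versions, a single dynamic-shape graph after the first recompile), but it
# never causes a graph break.
for length in (7, 8, 9):
    print(attend_prefix(key_cache, query, length).shape)

# By contrast, deriving the bound from a tensor *value* inside the compiled
# region, e.g. `int(cache_position[-1]) + 1` via `.item()`, is a
# data-dependent operation and is rejected with fullgraph=True.
```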

@ydshieh (Collaborator, Author) commented on May 21, 2024:

> Regarding the API (_length): I understand why it is done. Without an integer in the signature of GemmaSdpaAttention.forward, slicing tensors like key_states will always fail, either because it is a data-dependent operation (=forbidden) or because it produces a variable-length tensor.

Exactly!

For option 2, I am a bit worried. For example,

```python
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
```

Currently, cache_position has only one element (after the first step). If we go for option 2, it will be full length, so we would be updating the whole cache. Of course, the key_states and value_states arguments passed to update only cover the last part of the sequence, so we would have to slice cache_position here too, and the data-dependent-operation issue pops up again.

I personally would prefer option 3 for its simplicity, as long as we can rebuild the tensor (cache_position) required by update and the other places that need it. I would like to hear from @ArthurZucker too.
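
For reference, a simplified paraphrase of what a static-cache update does with cache_position (my own sketch, not the exact StaticCache.update code), which is where the concern about option 2 shows up:

```python
import torch

# key_cache / value_cache: (batch, heads, max_cache_len, head_dim)
def update(key_cache, value_cache, key_states, value_states, cache_position):
    # cache_position holds the slot indices being written:
    #   prefill step:  torch.arange(prompt_len)       -> key_states covers the prompt
    #   decode steps:  a single index, e.g. tensor([t]) -> key_states covers one token
    key_cache[:, :, cache_position] = key_states
    value_cache[:, :, cache_position] = value_states
    return key_cache, value_cache

# If cache_position were always the full torch.arange(t + 1) (option 2),
# key_states during decoding would still hold only the newest token, so this
# function would have to slice cache_position down to the last q_len entries
# first, which brings back the slicing problem described above.
```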

@ArthurZucker (Collaborator) commented:
Hey, before having a look: when you mention speedups, I don't think it makes sense to compute anything that does not use the full number of layers.
Also, how many tokens are generated? Is this speedup only for the prefill phase?

@ydshieh (Collaborator, Author) commented on May 23, 2024:

@ArthurZucker I am running on an A10, so even with gemma-2b (18 layers), I can only compile with a sequence length of 768.
However, as you can see from the tables, more layers means more speedup, and a longer sequence means more speedup too.

> Also, how many tokens are generated?

From 256 to 8192 (as long as it could compile within the A10 GPU memory).

The speedup gain and the reason behind it are fairly easy to see. However, is there any particular extra case you would like me to run?

@ArthurZucker (Collaborator) commented:
That is what was not clear to me: I wanted to know the number of generated tokens, not the prefill 😉
And most importantly, the new argument is pretty annoying 😓

@ydshieh (Collaborator, Author) commented on May 23, 2024:

Yes, we probably need to come up with a good new approach, as @gante suggested.
I will run the full 18 layers for google/gemma-2b in the meantime.

@ydshieh mentioned this pull request on May 23, 2024.
@ArthurZucker (Collaborator) left a comment:

Waiting for the final bench!

Comment on lines +575 to +581
```python
        if _length > 0:
            key_states = key_states[:, :, :_length, :]
            value_states = value_states[:, :, :_length, :]
            causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
```
Collaborator:

This can only be an int; if it's a list, there is bound to be a device transfer.

@ydshieh (Collaborator, Author):

Yeah, so far it is an int. Let me run the final bench first and then come back to the API design.

@zucchini-nlp (Member) commented:

@ydshieh I tried to use your implementation in my PR. I am also trying to get the actual length in compiled models, but in my case the length is used to decide which rope scaling to do. Therefore, passing the length as a model kwarg fails with dynamic control flow in the fullgraph setting.

So, what do you think about going back to the seen_tokens attribute of the cache class? It doesn't cause control-flow errors because the cache is still a model attribute, and we update seen_tokens as we generate instead of passing it to every forward pass. I think it will also work for the KV-cropping done in this PR.

cc @gante @ArthurZucker ?
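
A hedged sketch of the seen_tokens idea (my own illustration; the class and attribute names below are stand-ins, not the exact transformers implementation):

```python
import torch

class ToyStaticCache:
    def __init__(self, batch, heads, max_len, dim):
        self.key_cache = torch.zeros(batch, heads, max_len, dim)
        self.value_cache = torch.zeros(batch, heads, max_len, dim)
        # Plain Python int held on the cache object. The generation loop
        # (outside the compiled forward) bumps it once per step, so no extra
        # length kwarg has to travel through every forward signature.
        self.seen_tokens = 0

@torch.compile(fullgraph=True)
def toy_attention(cache: ToyStaticCache, query: torch.Tensor):
    # Read as an ordinary int -> the slice is data-independent, just like
    # `_length` in this PR. Whether each new value recompiles or becomes a
    # dynamic dimension depends on the torch version.
    keys = cache.key_cache[:, :, : cache.seen_tokens]
    return torch.matmul(query, keys.transpose(-1, -2))

cache = ToyStaticCache(1, 8, 1024, 64)
query = torch.randn(1, 8, 1, 64)
for step in range(3):          # stand-in for the generation loop
    cache.seen_tokens = step + 1
    print(toy_attention(cache, query).shape)
```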

@ydshieh (Collaborator, Author) commented on May 24, 2024:

> but in my case the length is used to decide which rope scaling to do. Therefore, passing the length as a model kwarg fails with dynamic control flow in the fullgraph setting.

Could you point me to the lines where the length is used and where the compile issue occurs? I could take a look.

Oh, you use the length as a conditional?

@zucchini-nlp (Member) commented:

Yes, in Phi3 RoPE it's used as a conditional, and I've been trying to compile it.

@ydshieh (Collaborator, Author) commented on May 24, 2024:

Do you still have that commit (where you combined your PR with mine, leading to the compile failure)? If so, could you share it please 🙏

@zucchini-nlp (Member) commented:

Sorry, I reverted your changes, but I just pushed the version that works for me, with seen_tokens. I get the length here and then use it in self.rotary_emb.

@zucchini-nlp (Member) commented:

Update: @ydshieh and I discussed using the length in cond control flow. It does work, but only in torch 2.3.0 or 2.4.0; in 2.2.0 it fails.

So this feature will also benefit Phi3 compilation when merged :)
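
For context, a heavily hedged sketch (my own, not the Phi3 implementation) of using the length inside a compiled conditional via torch.cond; the threshold and scaling factors are placeholders, and torch.cond is a prototype API whose availability differs across torch versions (per the note above, 2.3/2.4 but not 2.2):

```python
import torch

original_max_position_embeddings = 4096  # hypothetical threshold, not Phi3's real config

def short_factor(inv_freq):
    return inv_freq * 1.0          # placeholder scaling

def long_factor(inv_freq):
    return inv_freq * 0.5          # placeholder scaling, not Phi3's actual factors

@torch.compile(fullgraph=True)
def scaled_inv_freq(inv_freq: torch.Tensor, seq_len: torch.Tensor):
    # torch.cond traces both branches and evaluates the predicate at run time,
    # so no Python `if` on a tensor value is needed inside the compiled graph.
    return torch.cond(
        seq_len > original_max_position_embeddings, long_factor, short_factor, (inv_freq,)
    )

inv_freq = torch.rand(64)
print(scaled_inv_freq(inv_freq, torch.tensor(8192)))   # takes the "long" branch
print(scaled_inv_freq(inv_freq, torch.tensor(128)))    # takes the "short" branch
```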

@ydshieh (Collaborator, Author) commented on May 27, 2024:

Hi @gante

When I run with TORCH_LOGS="recompiles" on main (95b3c381) and this PR (862cde4c), the only recompilation in both commits happens at the second call to forward (see below), which makes sense.

So this PR doesn't introduce any extra recompilation.

(If we call generate with another input of a different sequence length, there is one more recompilation. After that, everything is ready to use and there is no further recompilation, even if a third input with yet another length is given.)

```
V0527 12:12:26.470280 140416858965824 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /transformers/src/transformers/models/gemma/modeling_gemma.py:1058
V0527 12:12:26.470280 140416858965824 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
V0527 12:12:26.470280 140416858965824 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 6, actual 7
11.347857
```

@ydshieh force-pushed the dynamic_length_in_static_cache branch from 862cde4 to b447901 on May 27, 2024.
@ydshieh (Collaborator, Author) commented on May 28, 2024:

> it should be trivial to build from any of the solutions above

> 1. cache_position becomes a list of integers instead of a tensor, and we use cache_position[-1] + 1 to slice the tensors;

But we will need a tensor in StaticCache.update (that is what @ArthurZucker told me), so I don't think this option is good.

> 2. we pass the full cache_position array (a torch.arange up to the sequence length); the different shape of cache_position in each GemmaSdpaAttention.forward will trigger recompilation, solving the dynamic-shape problem;
> 3. instead of cache_position, we use the sequence length (=_length, an int) to control generation with the static cache.

Given a length (say _length) or a full cache_position alone, it's not enough to reconstruct the (current) cache_position: we don't know whether we are in the first generation step or a later step, so we can't tell whether to reconstruct a full cache_position or a single (current) position for StaticCache.update.

We can probably use q_len, but it is obtained from an input tensor, and I don't know whether this will work well with torch.compile.
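
For concreteness, a hypothetical helper (my own sketch) showing the reconstruction from _length and q_len; whether taking q_len from the input tensor's shape composes well with torch.compile is exactly the open question here:

```python
import torch

def rebuild_cache_position(_length: int, q_len: int, device="cpu"):
    # prefill:       q_len == _length -> tensor([0, 1, ..., _length - 1])
    # decoding step: q_len == 1       -> tensor([_length - 1])
    return torch.arange(_length - q_len, _length, device=device)

print(rebuild_cache_position(7, 7))   # tensor([0, 1, 2, 3, 4, 5, 6])
print(rebuild_cache_position(8, 1))   # tensor([7])
```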

@gante Do you have any comment regarding this, or anything you think I could try?
