
Keep kv cache as list of tensors maybe better than one tensor #1562

Open
lingzhi98 opened this issue Apr 8, 2024 · 3 comments

Comments


lingzhi98 commented Apr 8, 2024

Describe the bug
If we keep the kv cache as a list of tensors, there is no need to concatenate the kv caches of each decoder block (https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gemma/gemma_causal_lm.py#L225). This would help model performance.
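For illustration, here is a minimal sketch of the two layouts; the shapes and variable names below are hypothetical, not the actual keras_nlp cache shapes:

```python
import jax.numpy as jnp

num_layers, batch, max_len, num_heads, head_dim = 4, 2, 128, 8, 64

# Per-layer caches, one tensor per decoder block (hypothetical shape).
per_layer = [jnp.zeros((batch, 2, max_len, num_heads, head_dim))
             for _ in range(num_layers)]

# Current layout: a single tensor holding every block's cache, built by
# stacking/concatenating the per-layer results on each generation step.
stacked_cache = jnp.stack(per_layer, axis=1)  # (batch, layers, 2, len, heads, dim)

# Proposed layout: keep the caches as a list (or tuple) of per-layer tensors,
# so each block reads and writes only its own tensor and no concatenation
# is needed.
list_cache = per_layer
```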

Expected behavior
Remove the unnecessary concatenation to improve performance.

github-actions bot added the Gemma (Gemma model specific issues) label on Apr 8, 2024
lingzhi98 changed the title from "Keep kv cache as list of tensors maybe better than a one tensor" to "Keep kv cache as list of tensors maybe better than one tensor" on Apr 8, 2024
lingzhi98 (Author)

Splitting the kv cache into separate key and value caches is also important (https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gemma/gemma_attention.py#L166).
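As a rough illustration (again with hypothetical shapes), a combined layout forces attention to split the tensor on every call, whereas separate key/value caches do not:

```python
import jax.numpy as jnp

batch, max_len, num_heads, head_dim = 2, 128, 8, 64

# Combined layout: one tensor with a size-2 axis that attention must
# slice apart into keys and values on every call.
kv_cache = jnp.zeros((batch, 2, max_len, num_heads, head_dim))
key_cache, value_cache = kv_cache[:, 0], kv_cache[:, 1]

# Split layout: independent key and value tensors, so each cache update can
# be a plain in-place write with no preceding slice.
key_cache = jnp.zeros((batch, max_len, num_heads, head_dim))
value_cache = jnp.zeros((batch, max_len, num_heads, head_dim))
```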

SuryanarayanaY added the type:feature (New feature or request) label on Apr 8, 2024
mattdangerw (Member) commented Apr 8, 2024

@lingzhi98 thanks! We are planning some generation improvements, so we will definitely check this out. Agreed that we can let performance be our guide, probably JAX-compiled performance in particular.

Were you thinking of a specific backend/compiled with XLA/not compiled? What's motivating the suggestion?

lingzhi98 (Author) commented Apr 9, 2024

I use JAX as the Keras backend. I have seen the concatenation become the main overhead as batch size increases. Because the kv caches are kept as one tensor, we need to slice the kv cache to get the corresponding key/value cache, compute the attention output, and then update the cache. The dynamic-update-slice fusion is blocked by this slice op (https://github.com/openxla/xla/blob/main/xla/service/gpu/ir_emission_utils.cc#L472), which hurts performance again.
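The pattern being described looks roughly like this in JAX (hypothetical shapes and helper names, not the actual keras_nlp code):

```python
import jax
import jax.numpy as jnp

batch, num_layers, max_len, num_heads, head_dim = 2, 4, 128, 8, 64
cache = jnp.zeros((batch, num_layers, 2, max_len, num_heads, head_dim))
new_key = jnp.ones((batch, 1, num_heads, head_dim))  # key for one new token

def update_combined(cache, new_key, layer, index):
    # With a single big cache tensor, we first slice out this layer's key
    # cache ...
    key_cache = cache[:, layer, 0]
    # ... then write the new token with dynamic_update_slice and scatter the
    # result back into the big tensor. The leading slice is what prevents XLA
    # from fusing the dynamic-update-slice in place.
    key_cache = jax.lax.dynamic_update_slice(key_cache, new_key, (0, index, 0, 0))
    return cache.at[:, layer, 0].set(key_cache)

def update_per_layer(key_cache, new_key, index):
    # With a per-layer (and per key/value) cache there is no preceding slice,
    # so the update stays a single in-place dynamic-update-slice.
    return jax.lax.dynamic_update_slice(key_cache, new_key, (0, index, 0, 0))
```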

Labels
Gemma (Gemma model specific issues), stat:awaiting keras-eng, type:feature (New feature or request)
Projects
None yet
Development

No branches or pull requests

3 participants