lingzhi98 changed the title from "Keep kv cache as list of tensors maybe better than a one tensor" to "Keep kv cache as list of tensors maybe better than one tensor" on Apr 8, 2024
@lingzhi98 thanks! We are planning some generation improvements, so we will definitely check this out. Agreed that we can let performance be our guide, particularly JAX-compiled performance.
Were you thinking of a specific backend, compiled with XLA or not compiled? What's motivating the suggestion?
I use JAX as the Keras backend. I have seen the concatenation become the main overhead as batch size increases. Because the KV cache is kept as one tensor, we need to slice it to get the corresponding key/value cache for each layer, compute the attention output, and then update the cache. Dynamic-update-slice fusion is blocked by this slice op (https://github.com/openxla/xla/blob/main/xla/service/gpu/ir_emission_utils.cc#L472), which hurts performance further.
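To illustrate the slicing issue, here is a minimal sketch (not the actual keras-nlp code; shapes, names, and the `attend_*` helpers are hypothetical) contrasting a per-layer update through one stacked cache tensor, which needs a slice before the dynamic-update-slice, with a per-layer update on a plain list of tensors, which does not:

```python
import jax
import jax.numpy as jnp

num_layers, batch, heads, seq_len, head_dim = 2, 1, 4, 8, 16

# Stacked layout: all layers' key/value caches fused into one tensor,
# shape (num_layers, 2, batch, heads, seq_len, head_dim).
stacked_cache = jnp.zeros((num_layers, 2, batch, heads, seq_len, head_dim))

def attend_stacked(cache, layer, step, new_kv):
    # To touch one layer's cache we must first slice it out of the big
    # tensor. This slice sits between the producer and the update below,
    # which is what blocks XLA's dynamic-update-slice fusion.
    layer_cache = jax.lax.dynamic_index_in_dim(cache, layer, axis=0, keepdims=False)
    # Write the new key/value for this decode step at position `step`
    # along the sequence axis (axis 3 of the per-layer cache).
    layer_cache = jax.lax.dynamic_update_slice_in_dim(layer_cache, new_kv, step, axis=3)
    # Then write the layer cache back into the stacked tensor.
    return jax.lax.dynamic_update_index_in_dim(cache, layer_cache, layer, axis=0)

# List layout: one tensor per decoder block; the update applies directly,
# with no cross-layer slice in the compiled graph.
list_cache = [jnp.zeros((2, batch, heads, seq_len, head_dim)) for _ in range(num_layers)]

def attend_list(cache, layer, step, new_kv):
    cache[layer] = jax.lax.dynamic_update_slice_in_dim(cache[layer], new_kv, step, axis=3)
    return cache
```

In the list layout the layer index is a Python-level lookup, so each layer's `dynamic_update_slice` can fuse with its producer under `jit`.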
Describe the bug
If we keep the KV cache as a list of tensors, there is no need to concatenate the KV caches of each decoder block (https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gemma/gemma_causal_lm.py#L225). This would help model performance.
Expected behavior
Remove useless concatenation to improve performance.
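A minimal sketch of the two cache layouts, simplified from the decode loop this issue refers to. The `decoder_block` stand-in and all shapes are hypothetical; the real update logic lives in keras-nlp:

```python
import jax.numpy as jnp

num_layers, batch, heads, seq, head_dim = 2, 1, 4, 8, 16

def decoder_block(x, layer_cache):
    # Stand-in for a real transformer decoder block; a real one would
    # attend over layer_cache and write new keys/values into it.
    return x + 1.0, layer_cache

x = jnp.zeros((batch, seq, head_dim))

# Current pattern (simplified): one stacked tensor, sliced per layer and
# re-built with a concatenate after every forward pass.
cache = jnp.zeros((num_layers, 2, batch, heads, seq, head_dim))
updated = []
for i in range(num_layers):
    x, layer_cache = decoder_block(x, cache[i])  # slice per layer
    updated.append(layer_cache[None, ...])
cache = jnp.concatenate(updated, axis=0)  # the concat this issue proposes removing

# Proposed pattern: keep a Python list and update entries in place, so
# neither the slice nor the concatenate appears in the compiled graph.
cache_list = [jnp.zeros((2, batch, heads, seq, head_dim)) for _ in range(num_layers)]
for i in range(num_layers):
    x, cache_list[i] = decoder_block(x, cache_list[i])
```

The list is ordinary Python structure at trace time, so the compiled program only contains the per-layer cache updates themselves.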