Use KV cache till input seq len for prefill phase #154

Merged: 5 commits merged on Apr 11, 2024
6 changes: 6 additions & 0 deletions examples/text-generation/utils.py
@@ -128,6 +128,12 @@ def setup_env(args):
os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")

if args.use_hpu_graphs and args.limit_hpu_graphs and not args.reuse_cache \
and args.bucket_internal:
# Based on the conditions above and the env variable below,
# HPU graphs clear_inputs() can be called.
os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1")

# Tweak generation so that it runs faster on Gaudi
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

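For reference, a minimal sketch of the gating above (the flag values are hypothetical, mirroring the text-generation example's arguments): the environment variable is only set when HPU graphs are used with limit_hpu_graphs and bucket_internal while reuse_cache is off, and os.environ.setdefault leaves any value that was already exported untouched.

```python
import os
from types import SimpleNamespace

# Hypothetical flag combination mirroring the arguments checked in setup_env()
args = SimpleNamespace(use_hpu_graphs=True, limit_hpu_graphs=True, reuse_cache=False, bucket_internal=True)

if args.use_hpu_graphs and args.limit_hpu_graphs and not args.reuse_cache and args.bucket_internal:
    # setdefault only writes the value when the variable is not already set in the environment
    os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1")

print(os.environ.get("PT_HPUGRAPH_DISABLE_TENSOR_CACHE"))  # "1" unless it was exported beforehand
```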
60 changes: 49 additions & 11 deletions optimum/habana/transformers/generation/utils.py
@@ -171,6 +171,28 @@ def _get_hpu_graphs_kwargs(self, model_kwargs):
hpu_graphs_kwargs.update({"bypass_hpu_graphs": True})
return hpu_graphs_kwargs

def _pad_past_key_values(self, model_kwargs):
pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
print(f"PAD KV Cache by {pad_amount} tokens")
if model_kwargs["past_key_values"]:
for i in range(len(model_kwargs["past_key_values"])):
for j in range(len(model_kwargs["past_key_values"][i])):
if torch.is_tensor(model_kwargs["past_key_values"][i][j]):
model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount))
if model_kwargs.get("lazy_mode" , False):
self.htcore_generation.mark_step()

def _remove_past_key_values(self, model_kwargs):
if model_kwargs["past_key_values"]:
for i in range(len(model_kwargs["past_key_values"])):
for j in range(len(model_kwargs["past_key_values"][i])):
if torch.is_tensor(model_kwargs["past_key_values"][i][j]):
t = model_kwargs["past_key_values"][i][j]
del t
model_kwargs["past_key_values"][i][j] = None
del model_kwargs["past_key_values"]
model_kwargs["past_key_values"] = None

def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
@@ -185,10 +207,11 @@ def _update_model_kwargs_for_generation(
"""
# mark to identify starting from second token
model_kwargs["first_token"] = False
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if not model_kwargs.get("pad_done", False):
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state

@@ -603,11 +626,6 @@ def generate(
assert generation_config.bucket_size
if generation_config.bucket_internal:
assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal"
assert generation_config.reuse_cache, "please set reuse_cache to use bucket_internal"
if generation_config.reuse_cache and not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
), "please set bucket_internal along with reuse_cache and bucket_size"

if generation_config.static_shapes:
# Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
@@ -723,7 +741,9 @@ def generate(
calculated_max_length,
token_idx
)
model_kwargs["kv_cache_len"] = calculated_max_length
if generation_config.use_cache:
model_kwargs["kv_cache_len"] = calculated_max_length
model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens

if self.config.model_type in ["llama", "falcon", "mistral"]:
if self.config.max_position_embeddings < calculated_max_length:
@@ -1394,7 +1414,11 @@ def greedy_search(
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"

greedy_first = True
model_kwargs["pad_done"] = False
model_kwargs["lazy_mode"] = lazy_mode

while True:
if lazy_mode:
self.htcore_generation.mark_step()
@@ -1417,7 +1441,6 @@
)

# prepare model inputs
model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
@@ -1526,6 +1549,21 @@ def greedy_search(
if this_peer_finished and not synced_gpus:
break

if not model_kwargs.get("pad_done", False) and not model_kwargs.get("reuse_cache", False) \
and bucket_internal:
# Pad the returned past key value tensors from the prefill forward pass to the maximum length
# before starting the decode phase.
self._pad_past_key_values(model_kwargs)
model_kwargs["pad_done"] = True

if model_kwargs.get("use_hpu_graphs", False) and model_kwargs.get("limit_hpu_graphs", False) \
and not model_kwargs.get("reuse_cache", False) and bucket_internal:
# Clear the HPU graph input tensors of the decode phase after the generation while loop completes
print("CLEAR HPU GRAPH INPUTS OF DECODE PHASE")
self.clear_inputs()
# Delete past key value tensors
self._remove_past_key_values(model_kwargs)

hb_profer.stop()
if streamer is not None:
streamer.end()
2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -146,6 +146,8 @@ def adapt_transformers_to_gaudi():
GaudiGenerationMixin.update_model_kwargs_for_bucketing
)
transformers.generation.GenerationMixin._get_hpu_graphs_kwargs = GaudiGenerationMixin._get_hpu_graphs_kwargs
transformers.generation.GenerationMixin._pad_past_key_values = GaudiGenerationMixin._pad_past_key_values
transformers.generation.GenerationMixin._remove_past_key_values = GaudiGenerationMixin._remove_past_key_values
transformers.generation.GenerationMixin._expand_inputs_for_generation = staticmethod(
GaudiGenerationMixin._expand_inputs_for_generation
)
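The assignments above follow the standard monkey-patching pattern: functions defined on the Gaudi mixin are bound onto transformers' GenerationMixin so that every model picks up the new helpers. A toy illustration with hypothetical classes:

```python
# Hypothetical classes; this only illustrates the patching pattern, not the real mixins.
class UpstreamMixin:
    def helper(self):
        return "upstream behaviour"

class GaudiMixin:
    def helper(self):
        return "gaudi behaviour"

# Re-bind the mixin's function onto the upstream class, as adapt_transformers_to_gaudi() does.
UpstreamMixin.helper = GaudiMixin.helper
print(UpstreamMixin().helper())  # "gaudi behaviour"
```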
17 changes: 13 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -50,7 +50,6 @@
import habana_frameworks.torch.core as htcore



def gaudi_llama_rmsnorm_forward(self, hidden_states):
"""
Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -302,7 +301,8 @@ def pre_attn_forward(
if past_key_value is None:
past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
past_key_value = (past_key, past_value)
# Return list instead of tuple
past_key_value = [past_key, past_value]
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)

@@ -399,6 +399,11 @@ def pre_attn_forward(
if not output_attentions:
attn_weights = None

if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
# During the decode phase (q_len == 1), return only the past key value shapes instead of the tensors
# to avoid making the past key values persistent output tensors of HPU graphs.
past_key_value = (past_key_value[0].shape, past_key_value[1].shape)

return attn_output, attn_weights, past_key_value
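A small sketch of the trick above (shapes are made up): a tensor's .shape is just a torch.Size tuple, so returning shapes during decode avoids handing the cache tensors back through the graph outputs while size information stays available to callers.

```python
import torch

# Hypothetical decode-phase cache tensors
past_key = torch.zeros(1, 8, 160, 64)
past_value = torch.zeros(1, 8, 160, 64)

# torch.Size is a tuple of ints, so no reference to the cache tensors escapes here
past_key_value = (past_key.shape, past_value.shape)
print(past_key_value[0])  # torch.Size([1, 8, 160, 64])
```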

def attention_all_reduce(self, attn_output):
@@ -732,6 +737,8 @@ def forward(
False,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
None,
use_fused_rope,
)
else:
@@ -898,6 +905,7 @@ def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs
):
reuse_cache = kwargs.get("reuse_cache")
bucket_internal = kwargs.get("bucket_internal")
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
Expand Down Expand Up @@ -929,8 +937,9 @@ def prepare_inputs_for_generation(
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
elif reuse_cache and token_idx is not None:
# With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
elif (reuse_cache or bucket_internal) and token_idx is not None:
# The KV cache is pre-allocated with reuse_cache or will be padded with bucket_internal,
# hence for the 1st token we can slice the inputs up to token_idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
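As a sketch of the slicing above (the values are made up): with static shapes the prompt is right-padded to a fixed bucket length and token_idx marks the real prompt length, so slicing up to token_idx feeds only the actual tokens and their mask to the prefill forward pass.

```python
import torch

input_ids = torch.tensor([[15, 27, 301, 0, 0, 0]])   # prompt right-padded to a bucket of 6
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])
token_idx = 3                                        # real prompt length (a tensor in the actual code)

print(input_ids[:, :token_idx])        # keeps only the three real tokens
print(attention_mask[:, :token_idx])   # and the matching mask entries
```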
