diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index af508a1732..33a29f91dc 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -128,6 +128,12 @@ def setup_env(args):
     os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
     os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")

+    if args.use_hpu_graphs and args.limit_hpu_graphs and not args.reuse_cache \
+        and args.bucket_internal:
+        # Based on the conditions above and the env variable below,
+        # we can call clear_inputs() on the HPU graphs.
+        os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1")
+
     # Tweak generation so that it runs faster on Gaudi
     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index c95c1f033c..fdcf315cad 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -171,6 +171,28 @@ def _get_hpu_graphs_kwargs(self, model_kwargs):
             hpu_graphs_kwargs.update({"bypass_hpu_graphs": True})
         return hpu_graphs_kwargs

+    def _pad_past_key_values(self, model_kwargs):
+        pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
+        print(f"PAD KV Cache by {pad_amount} tokens")
+        if model_kwargs["past_key_values"]:
+            for i in range(len(model_kwargs["past_key_values"])):
+                for j in range(len(model_kwargs["past_key_values"][i])):
+                    if torch.is_tensor(model_kwargs["past_key_values"][i][j]):
+                        model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount))
+        if model_kwargs.get("lazy_mode", False):
+            self.htcore_generation.mark_step()
+
+    def _remove_past_key_values(self, model_kwargs):
+        if model_kwargs["past_key_values"]:
+            for i in range(len(model_kwargs["past_key_values"])):
+                for j in range(len(model_kwargs["past_key_values"][i])):
+                    if torch.is_tensor(model_kwargs["past_key_values"][i][j]):
+                        t = model_kwargs["past_key_values"][i][j]
+                        del t
+                        model_kwargs["past_key_values"][i][j] = None
+        del model_kwargs["past_key_values"]
+        model_kwargs["past_key_values"] = None
+
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
@@ -185,10 +207,11 @@ def _update_model_kwargs_for_generation(
         """
         # mark to identify starting from second token
         model_kwargs["first_token"] = False
-        # update past_key_values
-        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        if not model_kwargs.get("pad_done", False):
+            # update past_key_values
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )

         if getattr(outputs, "state", None) is not None:
             model_kwargs["state"] = outputs.state
@@ -603,11 +626,6 @@ def generate(
             assert generation_config.bucket_size
         if generation_config.bucket_internal:
             assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal"
-            assert generation_config.reuse_cache, "please set reuse_cache to use bucket_internal"
-        if generation_config.reuse_cache and not generation_config.bucket_internal:
-            assert (
-                generation_config.bucket_size <= 0
-            ), "please set bucket_internal along with reuse_cache and bucket_size"

         if generation_config.static_shapes:
             # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
@@ -723,7 +741,9 @@ def generate(
                     calculated_max_length, token_idx
                 )
-            model_kwargs["kv_cache_len"] = calculated_max_length
+            if generation_config.use_cache:
+                model_kwargs["kv_cache_len"] = calculated_max_length
+                model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens

             if self.config.model_type in ["llama", "falcon", "mistral"]:
                 if self.config.max_position_embeddings < calculated_max_length:
@@ -1394,7 +1414,11 @@ def greedy_search(
                 inc = iter(incrementor(bucket_size, prompt_len))
             if bucket_size > 0:
                 assert "position_ids" not in model_kwargs, "Untested path"
+        greedy_first = True
+        model_kwargs["pad_done"] = False
+        model_kwargs["lazy_mode"] = lazy_mode
+
         while True:
             if lazy_mode:
                 self.htcore_generation.mark_step()

@@ -1417,7 +1441,6 @@
                 )

             # prepare model inputs
-            model_kwargs["lazy_mode"] = lazy_mode
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
@@ -1526,6 +1549,21 @@
             if this_peer_finished and not synced_gpus:
                 break

+        if not model_kwargs.get("pad_done", False) and not model_kwargs.get("reuse_cache", False) \
+            and bucket_internal:
+            # Pad the past key value tensors returned by the prefill forward pass to the maximum length
+            # before starting the decode phase.
+            self._pad_past_key_values(model_kwargs)
+            model_kwargs["pad_done"] = True
+
+        if model_kwargs.get("use_hpu_graphs", False) and model_kwargs.get("limit_hpu_graphs", False) \
+            and not model_kwargs.get("reuse_cache", False) and bucket_internal:
+            # Clear the HPU graph input tensors of the decode phase after the full generation while loop
+            print("CLEAR HPU GRAPH INPUTS OF DECODE PHASE")
+            self.clear_inputs()
+            # Delete the past key value tensors
+            self._remove_past_key_values(model_kwargs)
+
         hb_profer.stop()
         if streamer is not None:
             streamer.end()
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index e96bde90ba..26c5127235 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -146,6 +146,8 @@ def adapt_transformers_to_gaudi():
         GaudiGenerationMixin.update_model_kwargs_for_bucketing
     )
     transformers.generation.GenerationMixin._get_hpu_graphs_kwargs = GaudiGenerationMixin._get_hpu_graphs_kwargs
+    transformers.generation.GenerationMixin._pad_past_key_values = GaudiGenerationMixin._pad_past_key_values
+    transformers.generation.GenerationMixin._remove_past_key_values = GaudiGenerationMixin._remove_past_key_values
     transformers.generation.GenerationMixin._expand_inputs_for_generation = staticmethod(
         GaudiGenerationMixin._expand_inputs_for_generation
     )
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index c588b63309..eeeef37917 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -50,7 +50,6 @@
 import habana_frameworks.torch.core as htcore


-
 def gaudi_llama_rmsnorm_forward(self, hidden_states):
     """
     Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -302,7 +301,8 @@ def pre_attn_forward(
             if past_key_value is None:
                 past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
                 past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
-                past_key_value = (past_key, past_value)
+                # Return a list instead of a tuple so the KV cache entries can be replaced in place later
+                past_key_value = [past_key, past_value]
             key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
             value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
@@ -399,6 +399,11 @@ def pre_attn_forward(
         if not output_attentions:
             attn_weights = None

+        if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
+            # Return only the past key value shapes and not the tensors during the decode phase (q_len is 1)
+            # to avoid making the past key values persistent output tensors of HPU graphs.
+            past_key_value = (past_key_value[0].shape, past_key_value[1].shape)
+
         return attn_output, attn_weights, past_key_value

     def attention_all_reduce(self, attn_output):
@@ -732,6 +737,8 @@ def forward(
                     False,
                     use_flash_attention,
                     flash_attention_recompute,
+                    flash_attention_causal_mask,
+                    None,
                     use_fused_rope,
                 )
             else:
@@ -898,6 +905,7 @@ def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs
     ):
         reuse_cache = kwargs.get("reuse_cache")
+        bucket_internal = kwargs.get("bucket_internal")
         if past_key_values is not None:
             if token_idx is not None:
                 input_ids = torch.index_select(input_ids, 1, token_idx - 1)
@@ -929,8 +937,9 @@ def prepare_inputs_for_generation(
                 and cache_length + input_ids.shape[1] > max_cache_length
             ):
                 attention_mask = attention_mask[:, -max_cache_length:]
-        elif reuse_cache and token_idx is not None:
-            # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
+        elif (reuse_cache or bucket_internal) and token_idx is not None:
+            # The KV cache is pre-allocated with reuse_cache or will be padded with bucket_internal,
+            # hence for the first token we can slice the inputs up to token_idx for the forward pass.
             input_ids = input_ids[:, :token_idx]
             attention_mask = attention_mask[:, :token_idx]
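Note (not part of the patch): a minimal, self-contained sketch of what the (0, 0, 0, pad_amount) pad spec in _pad_past_key_values does to one KV cache tensor. The shapes and variable names below are illustrative assumptions, not values taken from the PR. torch.nn.functional.pad fills dimensions starting from the last one, so this spec leaves the head dimension untouched and appends pad_amount positions to the sequence dimension, which is how the prefill-phase cache is grown to a fixed maximum length before the decode phase starts.

import torch
import torch.nn.functional as F

# Illustrative KV cache layout: [batch, num_heads, seq_len, head_dim] (assumed)
batch, num_heads, seq_len, head_dim = 2, 4, 128, 64
key_cache = torch.zeros(batch, num_heads, seq_len, head_dim)

# Plays the role of model_kwargs["kv_cache_pad_len"] (set from max_new_tokens in generate())
pad_amount = 32

# (0, 0, 0, pad_amount): last dim (head_dim) padded by (0, 0), i.e. unchanged;
# second-to-last dim (seq_len) padded by (0, pad_amount), i.e. pad_amount new slots at the end
padded = F.pad(key_cache, (0, 0, 0, pad_amount))

print(key_cache.shape)  # torch.Size([2, 4, 128, 64])
print(padded.shape)     # torch.Size([2, 4, 160, 64])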
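Note (not part of the patch): a small sketch of why pre_attn_forward now returns [past_key, past_value] as a list rather than a tuple. _pad_past_key_values assigns the padded tensor back into past_key_values[i][j], and item assignment requires a mutable container; with a tuple the same loop would raise a TypeError. The tensors below are placeholders chosen only for illustration.

import torch
import torch.nn.functional as F

kv_as_list = [torch.zeros(3), torch.zeros(3)]
kv_as_tuple = (torch.zeros(3), torch.zeros(3))

# In-place replacement, as done per layer in _pad_past_key_values: works on a list
kv_as_list[0] = F.pad(kv_as_list[0], (0, 2))
print(kv_as_list[0].shape)  # torch.Size([5])

# The same assignment on a tuple fails
try:
    kv_as_tuple[0] = F.pad(kv_as_tuple[0], (0, 2))
except TypeError as err:
    print(err)  # 'tuple' object does not support item assignment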