Use KV cache till input seq len for prefill phase #154
Conversation
Use the KV cache only up to the input sequence length during the prefill phase, and pad the KV cache to the full input + new-tokens length for the decode phase. After the full prompt has been processed, delete the KV cache tensors used as inputs by the HPU graphs, and ensure the KV cache is not returned as an output tensor during the decode phase. Deletion of the KV cache input tensors used by HPU graphs must be guarded by the PT_HPUGRAPH_DISABLE_TENSOR_CACHE environment variable. All of the changes are guarded by the bucket_internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
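To illustrate the two cache shapes involved, here is a minimal sketch (hypothetical helper names; the actual change lives in the model's attention code): prefill attends over a KV cache sliced to the input sequence length, and the cache is then padded out to the full input + new-tokens length, presumably so decode-phase HPU graph shapes stay static.

```python
import torch


def prefill_kv_view(key_cache: torch.Tensor, value_cache: torch.Tensor, input_len: int):
    # During prefill, attend only over the first `input_len` positions
    # instead of the fully allocated (input + max_new_tokens) cache.
    # Cache layout assumed: [batch, heads, seq_len, head_dim].
    return key_cache[:, :, :input_len, :], value_cache[:, :, :input_len, :]


def pad_kv_for_decode(key_states: torch.Tensor, value_states: torch.Tensor, max_len: int):
    # Before the decode phase, pad the prefill KV tensors out to the full
    # input + new-tokens length along the sequence dimension.
    pad = max_len - key_states.shape[-2]
    key_padded = torch.nn.functional.pad(key_states, (0, 0, 0, pad))
    value_padded = torch.nn.functional.pad(value_states, (0, 0, 0, pad))
    return key_padded, value_padded
```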
Updated command (remove `--reuse_cache`; setting `PT_HPUGRAPH_DISABLE_TENSOR_CACHE=1` is handled automatically):

```
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
    --model_name_or_path /mnt/weka/data/pytorch/llama2/Llama-2-70b-hf/ \
    --use_hpu_graphs \
    --use_kv_cache \
    --max_input_tokens 2048 \
    --max_new_tokens 2048 \
    --batch_size 200 \
    --attn_softmax_bf16 \
    --trim_logits \
    --bf16 \
    --warmup 2 \
    --n_iterations 2 \
    --limit_hpu_graphs \
    --bucket_internal \
    --bucket_size 128
```

Also requires the pytorch-integration patch: https://gerrit.habana-labs.com/#/c/408363/
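The "handled automatically" part could look roughly like the sketch below (function name is hypothetical; the PR sets the variable when internal bucketing is enabled):

```python
import os


def maybe_disable_hpu_graph_tensor_cache(bucket_internal: bool) -> None:
    # The KV cache input tensors can only be freed after prefill if the
    # HPU graph runtime does not keep its own cached reference to them,
    # hence the env variable must be set when internal bucketing is used.
    if bucket_internal:
        os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1")
```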
Commits:

* Use KV cache till input seq len for prefill phase. Pad KV cache to full input + new tokens len for decode phase. Delete the KV cache used as inputs by HPU graphs after full prompt generation. Ensure KV cache is not returned as output tensor during decode phase. Deletion of KV cache input tensor used by HPU graphs needs to be protected by PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable. All the changes are protected by bucket internal flag.
* Revert initialization of KV cache
* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag
* Remove os import
* Remove commented print

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
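The "delete the KV cache used as inputs by HPU graphs" step could look roughly like this sketch (module path and attribute names are hypothetical; the real change clears per-layer cache references once the prompt has been fully processed):

```python
def free_prefill_kv_cache(model) -> None:
    # Once the full prompt has been generated, the KV tensors captured as
    # HPU-graph inputs during prefill are dead weight; dropping the Python
    # references lets their device memory be reclaimed. This is only safe
    # with PT_HPUGRAPH_DISABLE_TENSOR_CACHE=1, otherwise the graph runtime
    # keeps its own reference and the memory is never released.
    for layer in model.model.layers:      # hypothetical module path
        layer.self_attn.past_key = None   # hypothetical attribute
        layer.self_attn.past_value = None # hypothetical attribute
```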