
Use KV cache till input seq len for prefill phase #154

Merged
merged 5 commits into HabanaAI:habana-main on Apr 11, 2024

Conversation

puneeshkhanna

Pad the KV cache to the full input + new tokens length for the decode phase. Delete the KV cache tensors used as inputs by HPU graphs after the full prompt has been processed, and ensure the KV cache is not returned as an output tensor during the decode phase. Deletion of the KV cache input tensors used by HPU graphs must be guarded by the PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All of these changes are protected by the bucket_internal flag.

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
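
To make the bucketed KV cache handling described above concrete, here is a minimal sketch. The helper names (`allocate_kv_cache`, `pad_kv_cache_for_decode`), the `(batch, heads, seq, head_dim)` cache layout, and the bucket rounding are assumptions for illustration only, not the actual optimum-habana implementation; the idea shown is sizing the cache to the input sequence length for prefill and padding it to input + new tokens length for decode.

```python
import torch
import torch.nn.functional as F

def allocate_kv_cache(batch, num_heads, head_dim, input_len, max_new_tokens,
                      bucket_size, bucket_internal, dtype=torch.bfloat16,
                      device="cpu"):  # "hpu" on Gaudi
    """Size the KV cache for the prefill phase.

    With bucket_internal, prefill only needs room for the input sequence
    (rounded up to a bucket boundary); otherwise allocate the full
    input + new tokens length up front.
    """
    if bucket_internal:
        prefill_len = ((input_len + bucket_size - 1) // bucket_size) * bucket_size
    else:
        prefill_len = input_len + max_new_tokens
    shape = (batch, num_heads, prefill_len, head_dim)
    key = torch.zeros(shape, dtype=dtype, device=device)
    value = torch.zeros(shape, dtype=dtype, device=device)
    return key, value

def pad_kv_cache_for_decode(key, value, input_len, max_new_tokens):
    """Before the decode phase, pad the prefill-sized cache along the sequence
    dimension to the full input + new tokens length so decode steps can write
    new entries into it."""
    target_len = input_len + max_new_tokens
    pad = target_len - key.shape[-2]
    if pad > 0:
        key = F.pad(key, (0, 0, 0, pad))     # pads dim -2 (sequence length)
        value = F.pad(value, (0, 0, 0, pad))
    return key, value

# Example: 2048-token prompt, 2048 new tokens, bucket_size=128 (as in the command below)
k, v = allocate_kv_cache(1, 8, 128, 2048, 2048, 128, bucket_internal=True)
k, v = pad_kv_cache_for_decode(k, v, 2048, 2048)
```
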
@puneeshkhanna (Author) commented Apr 10, 2024

Updated command (remove --reuse_cache; setting PT_HPUGRAPH_DISABLE_TENSOR_CACHE=1 is taken care of automatically):

python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py --model_name_or_path /mnt/weka/data/pytorch/llama2/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 2048 --max_new_tokens 2048 --batch_size 200 --attn_softmax_bf16 --trim_logits --bf16 --warmup 2 --n_iterations 2 --limit_hpu_graphs --bucket_internal --bucket_size 128

Also requires the pytorch-integration patch: https://gerrit.habana-labs.com/#/c/408363/
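
For illustration, a hedged sketch of what the automatic flag handling might look like. The helper name and its placement are assumptions and not the actual run_generation.py code; only the env variable name and the CLI flags come from this PR.

```python
import os

def maybe_disable_hpu_graph_tensor_cache(args):
    """Hypothetical sketch: with --bucket_internal (and without --reuse_cache),
    the KV cache tensors captured as HPU graph inputs are deleted after the
    prompt is processed; that deletion is only safe when
    PT_HPUGRAPH_DISABLE_TENSOR_CACHE is set, so enable it automatically."""
    if getattr(args, "bucket_internal", False) and not getattr(args, "reuse_cache", False):
        os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1")
```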

@dvarshney-habana dvarshney-habana merged commit 60b5d9b into HabanaAI:habana-main Apr 11, 2024
sushildubey171 pushed a commit that referenced this pull request Apr 12, 2024
* Use KV cache till input seq len for prefill phase.

Pad KV cache to full input + new tokens len for decode phase.
Delete the KV cache used as inputs by HPU graphs after full prompt generation.
Ensure KV cache is not returned as output tensor during decode phase.
Deletion of KV cache input tensor used by HPU graphs needs to be protected by
PT_HPUGRAPH_DISABLE_TENSOR_CACHE env variable.
All the changes are protected by bucket internal flag.

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>

* Revert initialization of KV cache

* Set PT_HPUGRAPH_DISABLE_TENSOR_CACHE flag

* remove os import

* remove commented print

---------

Signed-off-by: Puneesh Khanna <pkhanna@habana.ai>
astachowiczhabana pushed a commit that referenced this pull request Apr 19, 2024
astachowiczhabana pushed a commit that referenced this pull request Apr 22, 2024
astachowiczhabana pushed a commit that referenced this pull request Apr 24, 2024
astachowiczhabana pushed a commit that referenced this pull request Apr 24, 2024
puneeshkhanna pushed a commit to puneeshkhanna/optimum-habana-fork that referenced this pull request May 2, 2024
@astachowiczhabana

huggingface#1028
