
Memory requirements #7

Open
loretoparisi opened this issue Feb 14, 2024 · 5 comments

Comments

@loretoparisi

It would be worth providing the measured memory requirements for inference with the text models at 32K, 128K, 256K, 512K, and 1M token context windows, in both PyTorch and JAX.

@wilson1yan
Contributor

wilson1yan commented Feb 14, 2024

If using vLLM for inference (PyTorch model, FP16), I believe we used:

  • 1 80GB A100 for 32K
  • 2 80GB A100s for 128K
  • 4 80GB A100s for 256K
  • 8 80GB A100s for 512K

For each of the above, we served one model with tensor parallelism across the given number of devices. With 8 80GB A100s, I think the limit was around 650K–700K tokens. vLLM prints the maximum number of tokens it can support, given the number of cache blocks it allocates, so it should be easy to tell if you're using GPUs with different amounts of memory.
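
For a concrete starting point, here is a minimal sketch of that tensor-parallel setup using vLLM's offline Python API. The checkpoint name and context length are placeholders, not the exact configuration we ran:

```python
# Minimal sketch of tensor-parallel FP16 inference with vLLM's offline API.
# The checkpoint name and max_model_len are placeholders, not the exact
# configuration behind the numbers above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="LargeWorldModel/LWM-Text-Chat-128K",  # hypothetical checkpoint name
    dtype="float16",             # FP16 weights
    tensor_parallel_size=2,      # e.g. 2x 80GB A100s for the 128K setting
    max_model_len=131072,        # context window to reserve KV-cache blocks for
)

params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["<very long document> ... Question: ..."], params)
print(outputs[0].outputs[0].text)
```

On startup, vLLM logs how many GPU cache blocks it allocated, which is where the maximum-token figure mentioned above comes from.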

For JAX, I'm not too sure what the intermediate requirements were, but we needed a v4-256 to do inference on 1M tokens (full FP32 inference). I think more optimizations could be made (e.g. half precision, quantization) to reduce the requirements. Even at full precision, the requirements seemed higher than I expected, and there may be some JAX / XLA optimizations to be made (e.g. keeping it from padding certain dimensions, which we originally had a lot of trouble with).
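
To illustrate the half-precision point (this is not our actual JAX inference script; `load_checkpoint` and `model.apply` below are placeholders), casting the loaded parameter pytree to bfloat16 is the kind of change that roughly halves parameter memory relative to FP32:

```python
# Illustrative only: cast a JAX parameter pytree to bfloat16, roughly halving
# parameter memory versus FP32. `load_checkpoint` and `model.apply` are
# placeholders, not the project's actual API.
import jax
import jax.numpy as jnp

def to_bf16(params):
    """Cast every floating-point leaf of a parameter pytree to bfloat16."""
    def cast(x):
        return x.astype(jnp.bfloat16) if jnp.issubdtype(x.dtype, jnp.floating) else x
    return jax.tree_util.tree_map(cast, params)

# params = load_checkpoint(path)          # placeholder: however weights are loaded
# params = to_bf16(params)
# logits = model.apply(params, tokens)    # placeholder Flax-style forward pass
```

At 1M-token contexts the KV cache and activations dominate, so those would likely need similar treatment for the savings to matter.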

@blazorin

Any recommendation for running the model on smaller GPUs (T4)? It runs out of memory (JAX).

@Playerrrrr

Playerrrrr commented Mar 10, 2024

@wilson1yan Can you share the shell/bash script for setting up the inference server via vLLM (PyTorch model, FP16)?


@xloem

xloem commented Mar 10, 2024

I'm thinking an attention kernel optimization like top-k would be appropriate here. Could a user calculate their own position_ids and pass a subset of the tokens, maybe making multiple passes and dropping tokens that don't impact the results?
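
If it helps, here is a rough sketch of that idea with a Hugging Face-style causal LM: run the model on a subset of the tokens while keeping their original positions via position_ids. The checkpoint name and the selection heuristic are placeholders, and whether quality survives dropping tokens this way is exactly the open question:

```python
# Rough sketch: run the model on a subset of tokens while keeping their
# original positions via position_ids. The checkpoint name and the selection
# heuristic are placeholders; this is an idea, not a validated recipe.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "LargeWorldModel/LWM-Text-Chat-128K"  # hypothetical HF checkpoint
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

input_ids = tok("...very long context...", return_tensors="pt").input_ids  # (1, N)
n = input_ids.shape[1]

# Placeholder heuristic: keep the first 64 and the last 1024 tokens.
keep = torch.cat([torch.arange(min(64, n)), torch.arange(max(0, n - 1024), n)]).unique()

subset_ids = input_ids[:, keep]
position_ids = keep.unsqueeze(0)  # original positions, not 0..len(subset)-1

with torch.no_grad():
    out = model(input_ids=subset_ids, position_ids=position_ids)
```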

@MoonRide303

Aren't those requirements a bit high for a 7B model with a 32K context? Mistral 7B v0.2 (32K context) works absolutely fine on consumer-grade GPUs, especially when using quantized versions such as high-quality Q6_K GGUFs.
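
For comparison, this is roughly how a quantized Q6_K GGUF of a 7B model is typically run on a single consumer GPU with llama-cpp-python. The model path is a placeholder, and at 32K context the KV cache, not just the quantized weights, drives VRAM use:

```python
# Roughly how a quantized 7B Q6_K GGUF is run on a consumer GPU with
# llama-cpp-python. The model path is a placeholder; at 32K context the
# KV cache, not just the quantized weights, drives VRAM use.
from llama_cpp import Llama

llm = Llama(
    model_path="./mistral-7b-instruct-v0.2.Q6_K.gguf",  # placeholder path
    n_ctx=32768,       # request the full 32K context window
    n_gpu_layers=-1,   # offload all layers to the GPU
)

out = llm("Summarize the following document:\n...", max_tokens=256)
print(out["choices"][0]["text"])
```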
