Memory requirements #7
If using vLLM for inference (PyTorch model, FP16), I believe we used:
For each of the above, we served 1 model with tensor parallelism over the given number of devices. With 8 80GB A100s, I think the limit was around 650K - 700K tokens. vLLM prints the max number of tokens supported at startup (derived from the number of KV-cache blocks it allocated), so it should be easy to tell what fits if you're using GPUs with different amounts of memory. For JAX, I'm not too sure what the intermediate memory requirements were, but we needed a v4-256 to do inference on 1M tokens (full FP32 inference). I think more optimizations can be made (e.g. half precision, quantization, etc.) to shrink the requirements. Even at full precision, the requirements seemed higher than I expected, and there may be some JAX / XLA optimizations left on the table (e.g. keeping it from padding certain dimensions, which we originally had a lot of trouble with).
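For a rough sense of where the memory goes, here is a back-of-envelope KV-cache estimator. This is a sketch, not the repo's actual code, and the 32-layer / 32-KV-head / 128-head-dim config below is an illustrative 7B-class shape, not taken from this project:

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """FP16 KV-cache size for one sequence: keys + values across all layers."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

# Hypothetical 7B-class model without grouped-query attention:
# 32 layers, 32 KV heads, head_dim 128.
per_token = kv_cache_bytes(1, 32, 32, 128)
print(per_token)  # 524288 bytes = 512 KiB per token

one_m_gib = kv_cache_bytes(1_000_000, 32, 32, 128) / 2**30
print(f"{one_m_gib:.0f} GiB for 1M tokens")  # ~488 GiB
```

At ~488 GiB of cache for 1M tokens (on top of the weights), numbers in the 8x80GB / v4-256 range stop looking surprising.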
Any recommendations for running the model on smaller GPUs (e.g. a T4)? It runs out of memory (JAX).
@wilson1yan Can you share the shell/bash script for setting up the inference server via vLLM (PyTorch model, FP16)?
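A minimal sketch of such a launch, assuming vLLM's OpenAI-compatible server entrypoint; the model path, context length, and port below are placeholders, not values from this repo:

```shell
# Sketch: serve one model with vLLM over 8 GPUs in FP16.
# MODEL_PATH is a placeholder -- point it at your checkpoint.
MODEL_PATH=/path/to/model

python -m vllm.entrypoints.openai.api_server \
    --model "$MODEL_PATH" \
    --dtype float16 \
    --tensor-parallel-size 8 \
    --max-model-len 32768 \
    --port 8000
# At startup vLLM logs how many KV-cache blocks it allocated, which
# tells you the max token count these GPUs can actually serve.
```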
I'm thinking an attention-kernel optimization like top-k would be appropriate here. Could a user compute their own position_ids and pass only a subset of the tokens, perhaps making multiple passes and dropping tokens that don't impact the results?
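The top-k idea above can be sketched in plain NumPy: each query attends only to its k highest-scoring keys, so the softmax and weighted sum touch far fewer entries. This is an illustration of the pruning idea, not an optimized kernel and not code from this repo:

```python
import numpy as np

def topk_attention(q, k, v, topk):
    """Single-head attention where each query keeps only its top-k keys."""
    scores = q @ k.T / np.sqrt(q.shape[-1])              # (n_q, n_k)
    # Threshold each row at its k-th largest score; mask the rest to -inf.
    kth_idx = k.shape[0] - topk
    kth = np.partition(scores, kth_idx, axis=-1)[:, kth_idx][:, None]
    masked = np.where(scores >= kth, scores, -np.inf)
    weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(16, 8))
v = rng.normal(size=(16, 8))
out = topk_attention(q, k, v, topk=4)
print(out.shape)  # (4, 8)
```

A real memory win would need a kernel that avoids materializing the full score matrix; this only shows the selection logic.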
Aren't those requirements a bit high in the case of 7B with 32K context? Mistral 7B v0.2 (32K context) works absolutely fine on consumer-grade GPUs, especially when using quantized versions such as high-quality Q6_K GGUFs.
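The arithmetic supports this: Mistral 7B uses grouped-query attention, which shrinks the KV cache substantially. A back-of-envelope check, assuming the published Mistral-7B shape (32 layers, 8 KV heads via GQA, head_dim 128):

```python
# FP16 KV cache for a full 32K-token context, Mistral-7B-like config.
layers, kv_heads, head_dim, ctx, fp16_bytes = 32, 8, 128, 32_768, 2
kv_cache = 2 * layers * kv_heads * head_dim * ctx * fp16_bytes  # keys + values
print(kv_cache / 2**30)  # 4.0 GiB
```

So ~4 GiB of cache at full 32K context, on top of ~14 GB of FP16 weights (or far less quantized), which is why it fits on consumer cards.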
It would be worth providing the measured memory requirements for text-model inference at 32K, 128K, 256K, 512K and 1M-token context windows in both PyTorch and JAX.