Ray Disaggregated Serving MVP #106
Conversation
allenwang28 left a comment:
High-level comment: it looks like the main difference for is_disaggregated within PyTorchRayEngine is whether or not prefill returns outputs.
If the prefill/decode/interleave functionality is essentially the same, then it's really an implementation detail for the orchestrator to trigger the transfer. If so, is it possible to exclude is_disaggregated from the worker? That would reduce the complexity.
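For illustration only, a minimal sketch of what this suggestion might look like; the class and method names below are hypothetical stand-ins, not the actual jetstream-pytorch code:

```python
# Hypothetical sketch: the worker always returns its prefill output and never
# sees is_disaggregated; only the orchestrating engine decides whether to
# transfer that output to a decode slice.

class PrefillWorker:
    """Worker that runs prefill the same way in both modes."""

    def prefill(self, params, tokens, true_length):
        # Stand-in for the real model prefill; returns the prefix / KV-cache-like
        # object that decode will need later.
        return {"tokens": tokens[:true_length], "params": params}


class PyTorchRayEngine:
    """Orchestrator that owns the interleave vs. disaggregated decision."""

    def __init__(self, worker, is_disaggregated=False):
        self.worker = worker
        self.is_disaggregated = is_disaggregated

    def prefill(self, params, tokens, true_length):
        prefix = self.worker.prefill(params, tokens, true_length)
        if self.is_disaggregated:
            # Only the engine knows the prefill result must move to another
            # TPU Pod slice; the worker stays mode-agnostic.
            self._transfer_to_decode_slice(prefix)
        return prefix

    def _transfer_to_decode_slice(self, prefix):
        # Placeholder for the cross-slice transfer (e.g. via the Ray object store).
        pass
```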
Simplified the prefill call on the engine side. On the worker side: yes, they are the same for insert and decode, but I feel it's better to keep disaggregated and interleave separate for prefill, for several reasons:
I think that makes sense to me, thanks!
This PR enables PyTorch engine disaggregated serving on multiple TPU Pod slices.
This PR delivers:
Result validation:
Command:
python /home/{user}/jetstream-pytorch/run_interactive_disaggregated.py --size=7b --batch_size=1 --is_disaggregated=True --num_hosts=8 --decode_pod_slice_name={user}-tpu-vm-2 --model_name=llama-2 --max_cache_length=2048 --quantize_weights=False --quantize_kv_cache=False --checkpoint_path=/home/{user}/data/llama-2-7b-chat-safetensor/model.safetensors --tokenizer_path=/home/{user}/data/tokenizer.model --sharding_config=/home/{user}/jetstream-pytorch/default_shardings/llama.yaml

Interleave result:
Disaggregated result:
Next Steps:
5: Support multiple prefill engines and multiple decode engines
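As a rough sketch of how step 5 might be orchestrated, assuming a simple round-robin policy; all names here (MultiEngineOrchestrator, serve_request, the engine methods) are illustrative assumptions, not the existing jetstream-pytorch API:

```python
import itertools

# Hypothetical sketch: fan requests out over several prefill and decode engines
# by round-robin, with prefill and decode running on different TPU Pod slices.

class MultiEngineOrchestrator:
    def __init__(self, prefill_engines, decode_engines):
        # Round-robin over the available engines on each side.
        self._prefill_cycle = itertools.cycle(prefill_engines)
        self._decode_cycle = itertools.cycle(decode_engines)

    def serve_request(self, params, tokens, true_length, slot):
        prefill_engine = next(self._prefill_cycle)
        decode_engine = next(self._decode_cycle)

        # Prefill on one Pod slice, transfer the result, then decode on another.
        prefix = prefill_engine.prefill(params, tokens, true_length)
        decode_engine.insert(prefix, slot)
        return decode_engine.generate(params, slot)
```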