[RFC]: [V1] TPU support and multiple architecture support #12480
Comments
@bvrockwell feel free to CC relevant people from Google
Thanks for writing this up. To mitigate the scheduler changes needed when chunked prefill is not supported, is there any chance to implement a fake chunked-prefill wrapper (using the existing kernels) and bring in coherent support once actual chunked-prefill support lands? (This may lead to suboptimal performance for a short period, but would likely improve the overall code quality without hacking the scheduler upstream.)
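A hypothetical sketch of the fake chunked-prefill idea, for illustration only (none of these class or method names are actual vLLM APIs): the wrapper reports the full prompt length as its chunk budget, so a scheduler that assumes chunked prefill never actually splits a prompt, and the existing whole-prompt prefill kernels can be reused unchanged.

```python
# Illustrative sketch only; names are hypothetical, not vLLM APIs.

class FakeChunkedPrefillWrapper:
    """Adapts a backend without chunked-prefill kernels to a scheduler
    that assumes chunked prefill is always available."""

    def __init__(self, runner):
        # Backend model runner that only supports whole-prompt prefill.
        self.runner = runner

    def max_prefill_chunk_size(self, request) -> int:
        # Report the full prompt length as the "chunk" budget, so the
        # scheduler never splits a prompt across steps.
        return len(request.prompt_token_ids)

    def prefill_chunk(self, request, start: int, end: int):
        # Given the budget above, (start, end) always spans the whole
        # prompt, so we can delegate to the existing prefill kernel.
        assert start == 0 and end == len(request.prompt_token_ids)
        return self.runner.prefill(request)
```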
Regarding the refactoring changes, can we separate them into a dedicated PR (e.g., one PR for the device-agnostic refactoring of the model runner class)?
@WoosukKwon to comment on the model runner and worker. For the scheduler, we should really think about how to deal with unsupported features. Maybe we should:
I have updated the PR (#11936) with Chunked Prompt support, so now there is no need to change anything in the scheduler - it is the same for both GPU and CPU. @WoosukKwon, it would be good to get your next set of feedback.
Motivation.
We are in the process of adding Google TPU support to vLLM V1.
Here is the WIP PR #11936.
Since this is the first time another hardware backend is being added to V1, the PR includes some refactoring to avoid code duplication, which requires discussion and feedback.
Proposed Change.
Here is a summary of the changes this PR introduces:
Refactors the common logic of model_runner into model_runner_base.py in the following way (backend-specific virtual functions are marked "Virtual API"; see the sketch after this list):
__init__() => Has common config init
get_model() => Simply returns the model
get_kv_cache_spec() => Common logic for KV cache management
initialize_kv_cache() => Virtual API
execute_model() => Virtual API
load_model() => Virtual API
dummy_run() => Virtual API
profile_run() => Virtual API
capture_model() => Virtual API
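A minimal sketch of how model_runner_base.py could look under this split, assuming only the method names from the list above; the actual signatures in PR #11936 may differ.

```python
from abc import ABC, abstractmethod

class ModelRunnerBase(ABC):
    def __init__(self, vllm_config, device):
        # Common config init shared by all backends.
        self.vllm_config = vllm_config
        self.device = device
        self.model = None

    def get_model(self):
        # Simply returns the loaded model.
        return self.model

    def get_kv_cache_spec(self):
        # Common KV-cache spec computation derived from the model
        # config; identical across backends.
        ...

    # Backend-specific "Virtual APIs".
    @abstractmethod
    def initialize_kv_cache(self, kv_cache_config): ...
    @abstractmethod
    def execute_model(self, scheduler_output): ...
    @abstractmethod
    def load_model(self): ...
    @abstractmethod
    def dummy_run(self): ...
    @abstractmethod
    def profile_run(self): ...
    @abstractmethod
    def capture_model(self): ...
```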
Refactors the common logic of worker into worker_base.py in the following way (virtual functions marked "Virtual API"; see the sketch after this list):
__init__() => Has common config init, HF init, torch profiler init
load_model() => Calls load_model() of model_runner
compile_or_warm_up_model() => Calls capture_model() based on the enforce_eager param and sets the random seed
get_model() => Calls get_model() of model_runner
get_kv_cache_spec() => Calls get_kv_cache_spec() of model_runner
initialize_cache() => Calls initialize_kv_cache() of model_runner
profile() => Starts/stops profiler
check_health() => Empty function
init_device() => Virtual API
determine_available_memory() => Virtual API
execute_model() => Virtual API
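Likewise for the worker, a hedged sketch of worker_base.py: common logic delegates to the model runner, while device-specific setup stays abstract. Parameter names here are illustrative.

```python
from abc import ABC, abstractmethod

class WorkerBase(ABC):
    def __init__(self, vllm_config, local_rank, rank):
        # Common config / HF / torch-profiler init.
        self.vllm_config = vllm_config
        self.local_rank = local_rank
        self.rank = rank
        self.model_runner = None  # created by subclasses in init_device()
        self.profiler = None

    def load_model(self):
        self.model_runner.load_model()

    def compile_or_warm_up_model(self):
        # Capture graphs only when eager mode is off, then set the
        # random seed (attribute path is an assumption).
        if not self.vllm_config.model_config.enforce_eager:
            self.model_runner.capture_model()
        ...

    def get_model(self):
        return self.model_runner.get_model()

    def get_kv_cache_spec(self):
        return self.model_runner.get_kv_cache_spec()

    def initialize_cache(self, kv_cache_config):
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool):
        # Start or stop the torch profiler.
        ...

    def check_health(self):
        # Worker is healthy as long as it responds; nothing to do.
        pass

    # Backend-specific "Virtual APIs".
    @abstractmethod
    def init_device(self): ...
    @abstractmethod
    def determine_available_memory(self) -> int: ...
    @abstractmethod
    def execute_model(self, scheduler_output): ...
```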
Comments and feedback are very welcome.
Feedback Period.
One week
CC List.
@robertgshaw2-redhat @WoosukKwon @mgoin @tlrmchlsmth @youkaichao @simon-mo @njhill @comaniac @ywang96 @DarkLight1337 @SageMoore @bvrockwell
Any Other Things.
No response