-
Notifications
You must be signed in to change notification settings - Fork 12
Feature: Support downloading model weights on-the-fly from HuggingFace (#166) #167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fc843ed
5f790ff
0f22bec
9f2fdd2
38011be
4de3563
eb1e929
8b6a211
c68cb35
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,8 +37,18 @@ def __init__(self, params: dict[str, Any]): | |
| self.additional_binds = ( | ||
| f",{self.params['bind']}" if self.params.get("bind") else "" | ||
| ) | ||
| self.model_weights_path = str( | ||
| Path(self.params["model_weights_parent_dir"], self.params["model_name"]) | ||
| model_weights_path = Path( | ||
| self.params["model_weights_parent_dir"], self.params["model_name"] | ||
| ) | ||
| self.model_weights_exists = model_weights_path.exists() | ||
| self.model_weights_path = str(model_weights_path) | ||
| self.model_source = ( | ||
| self.model_weights_path | ||
| if self.model_weights_exists | ||
| else self.params["model_name"] | ||
| ) | ||
| self.model_bind_option = ( | ||
| f",{self.model_weights_path}" if self.model_weights_exists else "" | ||
| ) | ||
| self.env_str = self._generate_env_str() | ||
|
|
||
|
|
@@ -111,7 +121,9 @@ def _generate_server_setup(self) -> str: | |
| server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"])) | ||
| server_script.append( | ||
| SLURM_SCRIPT_TEMPLATE["bind_path"].format( | ||
| model_weights_path=self.model_weights_path, | ||
| model_weights_path=self.model_weights_path | ||
| if self.model_weights_exists | ||
| else "", | ||
| additional_binds=self.additional_binds, | ||
| ) | ||
| ) | ||
|
|
@@ -131,7 +143,6 @@ def _generate_server_setup(self) -> str: | |
| server_setup_str = server_setup_str.replace( | ||
| "CONTAINER_PLACEHOLDER", | ||
| SLURM_SCRIPT_TEMPLATE["container_command"].format( | ||
| model_weights_path=self.model_weights_path, | ||
| env_str=self.env_str, | ||
| ), | ||
| ) | ||
|
|
@@ -165,22 +176,27 @@ def _generate_launch_cmd(self) -> str: | |
| Server launch command. | ||
| """ | ||
| launcher_script = ["\n"] | ||
|
|
||
| vllm_args_copy = self.params["vllm_args"].copy() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure if this is necessary, as the model name should be parsed with the launch command, not as part of --vllm-args. |
||
| model_source = self.model_source | ||
| if "--model" in vllm_args_copy: | ||
| model_source = vllm_args_copy.pop("--model") | ||
|
|
||
| if self.use_container: | ||
| launcher_script.append( | ||
| SLURM_SCRIPT_TEMPLATE["container_command"].format( | ||
| model_weights_path=self.model_weights_path, | ||
| env_str=self.env_str, | ||
| ) | ||
| ) | ||
|
|
||
| launcher_script.append( | ||
| "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format( | ||
| model_weights_path=self.model_weights_path, | ||
| model_source=model_source, | ||
| model_name=self.params["model_name"], | ||
| ) | ||
| ) | ||
|
|
||
| for arg, value in self.params["vllm_args"].items(): | ||
| for arg, value in vllm_args_copy.items(): | ||
| if isinstance(value, bool): | ||
| launcher_script.append(f" {arg} \\") | ||
| else: | ||
|
|
@@ -225,11 +241,20 @@ def __init__(self, params: dict[str, Any]): | |
| if self.params["models"][model_name].get("bind") | ||
| else "" | ||
| ) | ||
| self.params["models"][model_name]["model_weights_path"] = str( | ||
| Path( | ||
| self.params["models"][model_name]["model_weights_parent_dir"], | ||
| model_name, | ||
| ) | ||
| model_weights_path = Path( | ||
| self.params["models"][model_name]["model_weights_parent_dir"], | ||
| model_name, | ||
| ) | ||
| model_weights_exists = model_weights_path.exists() | ||
| model_weights_path_str = str(model_weights_path) | ||
| self.params["models"][model_name]["model_weights_path"] = ( | ||
| model_weights_path_str | ||
| ) | ||
| self.params["models"][model_name]["model_weights_exists"] = ( | ||
| model_weights_exists | ||
| ) | ||
| self.params["models"][model_name]["model_source"] = ( | ||
| model_weights_path_str if model_weights_exists else model_name | ||
| ) | ||
|
|
||
| def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path: | ||
|
|
@@ -266,7 +291,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path: | |
| script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"]) | ||
| script_content.append( | ||
| BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format( | ||
| model_weights_path=model_params["model_weights_path"], | ||
| model_weights_path=model_params["model_weights_path"] | ||
| if model_params.get("model_weights_exists", True) | ||
| else "", | ||
| additional_binds=model_params["additional_binds"], | ||
| ) | ||
| ) | ||
|
|
@@ -283,19 +310,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path: | |
| model_name=model_name, | ||
| ) | ||
| ) | ||
| vllm_args_copy = model_params["vllm_args"].copy() | ||
| model_source = model_params.get( | ||
| "model_source", model_params["model_weights_path"] | ||
| ) | ||
| if "--model" in vllm_args_copy: | ||
| model_source = vllm_args_copy.pop("--model") | ||
|
|
||
| if self.use_container: | ||
| script_content.append( | ||
| BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format( | ||
| model_weights_path=model_params["model_weights_path"], | ||
| ) | ||
| BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format() | ||
| ) | ||
| script_content.append( | ||
| "\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format( | ||
| model_weights_path=model_params["model_weights_path"], | ||
| model_source=model_source, | ||
| model_name=model_name, | ||
| ) | ||
| ) | ||
| for arg, value in model_params["vllm_args"].items(): | ||
|
|
||
| for arg, value in vllm_args_copy.items(): | ||
| if isinstance(value, bool): | ||
| script_content.append(f" {arg} \\") | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -98,7 +98,7 @@ class SlurmScriptTemplate(TypedDict): | |
| f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop", | ||
| ], | ||
| "imports": "source {src_dir}/find_port.sh", | ||
| "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}", | ||
| "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g'),/dev,/tmp{{model_weights_path}}{{additional_binds}}", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like an error from merging the code? The /dev directory is already bound, so the extra handling for /dev/infiniband is no longer needed. |
||
| "container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\", | ||
| "activate_venv": "source {venv}/bin/activate", | ||
| "server_setup": { | ||
|
|
@@ -164,7 +164,7 @@ class SlurmScriptTemplate(TypedDict): | |
| ' && mv temp.json "$json_path"', | ||
| ], | ||
| "launch_cmd": [ | ||
| "vllm serve {model_weights_path} \\", | ||
| "vllm serve {model_source} \\", | ||
| " --served-model-name {model_name} \\", | ||
| ' --host "0.0.0.0" \\', | ||
| " --port $vllm_port_number \\", | ||
|
|
@@ -255,7 +255,7 @@ class BatchModelLaunchScriptTemplate(TypedDict): | |
| ], | ||
| "container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\", | ||
| "launch_cmd": [ | ||
| "vllm serve {model_weights_path} \\", | ||
| "vllm serve {model_source} \\", | ||
| " --served-model-name {model_name} \\", | ||
| ' --host "0.0.0.0" \\', | ||
| " --port $vllm_port_number \\", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like this member variable is never used anywhere?