Skip to content
57 changes: 57 additions & 0 deletions tests/vec_inf/client/test_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
)


@pytest.fixture(autouse=True)
def patch_model_weights_exists(monkeypatch):
    """Make model-weights existence checks report ``True`` by default.

    The fixture is ``autouse``, so it is applied to every test that can see
    it. It replaces the ``exists`` attribute on the ``Path`` class referenced
    by ``vec_inf.client._slurm_script_generator``; every instance of that
    class therefore reports that it exists until pytest undoes the patch at
    the end of the test. Tests that exercise the "weights are missing" branch
    install their own replacement with ``monkeypatch.setattr``.
    """

    def _always_exists(_path, *args, **kwargs):
        # Accept and ignore extra arguments: ``Path.exists`` takes a
        # keyword-only ``follow_symlinks`` on Python 3.12+, and a stub without
        # it would raise ``TypeError`` for any caller that passes it.
        return True

    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists", _always_exists
    )


class TestSlurmScriptGenerator:
"""Tests for SlurmScriptGenerator class."""

Expand Down Expand Up @@ -168,6 +176,21 @@ def test_generate_server_setup_singularity(self, singularity_params):
"module load " in setup
) # Remove module name since it's inconsistent between clusters

def test_generate_server_setup_singularity_no_weights(
    self, singularity_params, monkeypatch
):
    """Check the server setup omits the weights path when it is not on disk."""

    def _weights_missing(_path):
        # Stand-in for ``Path.exists``: report every path as absent.
        return False

    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists",
        _weights_missing,
    )

    script_gen = SlurmScriptGenerator(singularity_params)
    server_setup = script_gen._generate_server_setup()

    # The container setup is still emitted ...
    assert "ray stop" in server_setup
    # ... but the local weights directory must not appear anywhere in it.
    assert "/path/to/model_weights/test-model" not in server_setup

def test_generate_launch_cmd_venv(self, basic_params):
"""Test launch command generation with virtual environment."""
generator = SlurmScriptGenerator(basic_params)
Expand All @@ -187,6 +210,22 @@ def test_generate_launch_cmd_singularity(self, singularity_params):
assert "apptainer exec --nv" in launch_cmd
assert "source" not in launch_cmd

def test_generate_launch_cmd_singularity_no_local_weights(
    self, singularity_params, monkeypatch
):
    """Check the container launch command when the weights directory is absent."""

    def _weights_missing(_path):
        # Stand-in for ``Path.exists``: report every path as absent.
        return False

    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists",
        _weights_missing,
    )

    script_gen = SlurmScriptGenerator(singularity_params)
    command = script_gen._generate_launch_cmd()

    # Still launched through the container runtime.
    assert "exec --nv" in command
    # No bind option for the (missing) local weights directory.
    assert "--bind /path/to/model_weights/test-model" not in command
    # The model is served by name instead of by local path.
    assert "vllm serve test-model" in command

def test_generate_launch_cmd_boolean_args(self, basic_params):
"""Test launch command with boolean vLLM arguments."""
params = basic_params.copy()
Expand Down Expand Up @@ -391,6 +430,24 @@ def test_generate_model_launch_script_singularity(
mock_touch.assert_called_once()
mock_write_text.assert_called_once()

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_singularity_no_weights(
    self, mock_write_text, mock_touch, batch_singularity_params, monkeypatch
):
    """Check the batch launch script leaves out a weights path that is absent."""

    def _weights_missing(_path):
        # Stand-in for ``Path.exists``: report every path as absent.
        return False

    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists",
        _weights_missing,
    )

    batch_gen = BatchSlurmScriptGenerator(batch_singularity_params)
    launch_script = batch_gen._generate_model_launch_script("model1")

    assert launch_script.name == "launch_model1.sh"
    # First positional argument of the recorded ``write_text`` call, i.e. the
    # script text that would have been written to disk.
    written_text = mock_write_text.call_args[0][0]
    assert "/path/to/model_weights/model1" not in written_text

@patch("vec_inf.client._slurm_script_generator.datetime")
@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
Expand Down
69 changes: 51 additions & 18 deletions vec_inf/client/_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,18 @@ def __init__(self, params: dict[str, Any]):
self.additional_binds = (
f",{self.params['bind']}" if self.params.get("bind") else ""
)
self.model_weights_path = str(
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
model_weights_path = Path(
self.params["model_weights_parent_dir"], self.params["model_name"]
)
self.model_weights_exists = model_weights_path.exists()
self.model_weights_path = str(model_weights_path)
self.model_source = (
self.model_weights_path
if self.model_weights_exists
else self.params["model_name"]
)
self.model_bind_option = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this member variable is never used anywhere?

f",{self.model_weights_path}" if self.model_weights_exists else ""
)
self.env_str = self._generate_env_str()

Expand Down Expand Up @@ -111,7 +121,9 @@ def _generate_server_setup(self) -> str:
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
server_script.append(
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=self.model_weights_path,
model_weights_path=self.model_weights_path
if self.model_weights_exists
else "",
additional_binds=self.additional_binds,
)
)
Expand All @@ -131,7 +143,6 @@ def _generate_server_setup(self) -> str:
server_setup_str = server_setup_str.replace(
"CONTAINER_PLACEHOLDER",
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
),
)
Expand Down Expand Up @@ -165,22 +176,27 @@ def _generate_launch_cmd(self) -> str:
Server launch command.
"""
launcher_script = ["\n"]

vllm_args_copy = self.params["vllm_args"].copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is necessary, as the model name should be parsed with the launch command, not as part of `--vllm-args`.

model_source = self.model_source
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
launcher_script.append(
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
)
)

launcher_script.append(
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=self.model_weights_path,
model_source=model_source,
model_name=self.params["model_name"],
)
)

for arg, value in self.params["vllm_args"].items():
for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
launcher_script.append(f" {arg} \\")
else:
Expand Down Expand Up @@ -225,11 +241,20 @@ def __init__(self, params: dict[str, Any]):
if self.params["models"][model_name].get("bind")
else ""
)
self.params["models"][model_name]["model_weights_path"] = str(
Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
)
model_weights_path = Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
)
model_weights_exists = model_weights_path.exists()
model_weights_path_str = str(model_weights_path)
self.params["models"][model_name]["model_weights_path"] = (
model_weights_path_str
)
self.params["models"][model_name]["model_weights_exists"] = (
model_weights_exists
)
self.params["models"][model_name]["model_source"] = (
model_weights_path_str if model_weights_exists else model_name
)

def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path:
Expand Down Expand Up @@ -266,7 +291,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=model_params["model_weights_path"],
model_weights_path=model_params["model_weights_path"]
if model_params.get("model_weights_exists", True)
else "",
additional_binds=model_params["additional_binds"],
)
)
Expand All @@ -283,19 +310,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
model_name=model_name,
)
)
vllm_args_copy = model_params["vllm_args"].copy()
model_source = model_params.get(
"model_source", model_params["model_weights_path"]
)
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=model_params["model_weights_path"],
)
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format()
)
script_content.append(
"\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=model_params["model_weights_path"],
model_source=model_source,
model_name=model_name,
)
)
for arg, value in model_params["vllm_args"].items():

for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
script_content.append(f" {arg} \\")
else:
Expand Down
6 changes: 3 additions & 3 deletions vec_inf/client/_slurm_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class SlurmScriptTemplate(TypedDict):
f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
],
"imports": "source {src_dir}/find_port.sh",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g'),/dev,/tmp{{model_weights_path}}{{additional_binds}}",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an error from merging the code. The `/dev` directory is already bound, so the extra handling for `/dev/infiniband` is no longer needed.

"container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\",
"activate_venv": "source {venv}/bin/activate",
"server_setup": {
Expand Down Expand Up @@ -164,7 +164,7 @@ class SlurmScriptTemplate(TypedDict):
' && mv temp.json "$json_path"',
],
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
Expand Down Expand Up @@ -255,7 +255,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
],
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\",
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
Expand Down