Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Each model in `models.yaml` becomes an isolated Ray Serve deployment (`ModelDepl
- **Independent lifecycle** — one model crashing doesn't affect others
- **Per-model GPU budgeting** — `num_gpus` controls VRAM allocation (e.g. 0.70 for 70%)
- **Sequential startup** — models deploy one at a time to prevent memory spikes, ordered by tensor parallelism size (TP > 1 first)
- **Multi-deployment routing** — the same model name can appear multiple times with different configs (e.g. GPU + CPU). The gateway round-robins requests across all deployments sharing a name. Each deployment also supports `num_replicas` for scaling identical copies via Ray Serve's built-in load balancing

### Inference Loaders

Expand Down
33 changes: 33 additions & 0 deletions docs/model-configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Models are configured in `config/models.yaml`. Each entry defines one deployment
| `plugin` | string | Plugin module name (required when `loader: custom`); must be installed via `uv sync --extra <plugin>` |
| `num_gpus` | float | Fraction of a GPU to allocate (0.0–1.0); also sets vLLM `gpu_memory_utilization` |
| `num_cpus` | float | CPU units to allocate (default `0.1`) |
| `num_replicas` | int | Number of identical Ray Serve replicas for this deployment (default `1`) |
| `vllm_engine_kwargs` | object | Passed directly to the vLLM engine — see [vLLM engine args](https://docs.vllm.ai/en/latest/configuration/engine_args.html) |
| `diffusers_config` | object | Diffusers pipeline options (see below) |
| `plugin_config` | object | Plugin-specific options passed through to the plugin |
Expand Down Expand Up @@ -43,6 +44,38 @@ Example:
guidance_scale: 0.0
```

## Multi-Deployment Routing

You can run the same model on different hardware (e.g. GPU and CPU) by repeating the same `name` with different settings. The API exposes the model once under `/v1/models`, and round-robins requests across all deployments sharing that name.

Use `num_replicas` to scale identical copies of a single deployment (Ray Serve handles load balancing between replicas automatically).

```yaml
models:
# GPU instance with 2 replicas
- name: "kokoro"
model: "hexgrad/Kokoro-82M"
usecase: "tts"
loader: "custom"
plugin: "kokoro"
num_gpus: 0.07
num_replicas: 2
plugin_config:
onnx_provider: "CUDAExecutionProvider"

# CPU fallback
- name: "kokoro"
model: "hexgrad/Kokoro-82M"
usecase: "tts"
loader: "custom"
plugin: "kokoro"
num_gpus: 0
plugin_config:
onnx_provider: "CPUExecutionProvider"
```

In this example, requests to model `kokoro` reach three backends in total: the gateway round-robins between the two deployments (GPU and CPU), and within the GPU deployment Ray Serve balances load across its two replicas.

## Environment Variables

| Variable | Description | Default |
Expand Down
29 changes: 22 additions & 7 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ray import serve
from ray.serve.config import HTTPOptions

from yasha.infer.infer_config import ModelLoader, YashaConfig, YashaModelConfig
from yasha.infer.infer_config import ModelLoader, ModelUsecase, YashaConfig, YashaModelConfig
from yasha.infer.model_deployment import ModelDeployment
from yasha.logging import configure_logging, get_logger
from yasha.openai.api import YashaAPI
Expand Down Expand Up @@ -116,19 +116,34 @@ def main():
# Deploy models one at a time. serve.run() blocks until the deployment reaches
# RUNNING, ensuring each model fully initialises (and releases its load-time
# memory spike) before the next one starts.
model_handles = {}
#
# Multiple config entries may share the same `name` (e.g. one on GPU, one on
# CPU). Each gets a unique Ray deployment name (`name-1`, `name-2`, …) and
# their handles are grouped under the shared API-facing name for round-robin
# routing in the gateway.
model_handles: dict[str, tuple[list, ModelUsecase]] = {}
name_counters: dict[str, int] = {}
for config in sorted_models:
logger.info("Deploying model: %s", config.name)
count = name_counters.get(config.name, 0) + 1
name_counters[config.name] = count
deployment_name = f"{config.name}-{count}"

logger.info("Deploying model: %s (deployment: %s)", config.name, deployment_name)
handle = serve.run(
ModelDeployment.options(
name=config.name,
name=deployment_name,
num_replicas=config.num_replicas,
ray_actor_options=build_actor_options(config),
).bind(config),
name=config.name,
name=deployment_name,
route_prefix=None, # not exposed via HTTP — accessed only via handle
)
logger.info("Model ready: %s", config.name)
model_handles[config.name] = (handle, config.usecase)
logger.info("Model ready: %s (deployment: %s)", config.name, deployment_name)

if config.name in model_handles:
model_handles[config.name][0].append(handle)
else:
model_handles[config.name] = ([handle], config.usecase)

logger.info("All models ready, starting API gateway...")
serve.run(
Expand Down
45 changes: 45 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,48 @@ def test_multi_model_config(self):
def test_empty_models_list(self):
config = YashaConfig(models=[])
assert len(config.models) == 0

def test_duplicate_names_allowed(self):
    """Two config entries may share one model name (e.g. GPU + CPU variants)."""
    shared = dict(
        name="kokoro",
        model="hexgrad/Kokoro-82M",
        usecase=ModelUsecase.tts,
        loader=ModelLoader.custom,
        plugin="kokoro",
    )
    config = YashaConfig(
        models=[
            # GPU-flavoured entry
            YashaModelConfig(**shared, num_gpus=0.07),
            # CPU fallback entry with the same public name
            YashaModelConfig(**shared, num_gpus=0),
        ]
    )
    assert len(config.models) == 2
    assert config.models[0].name == "kokoro"
    assert config.models[1].name == "kokoro"


class TestNumReplicas:
    """The `num_replicas` field defaults to 1 and accepts explicit values."""

    @staticmethod
    def _make_config(**overrides):
        # Minimal valid vLLM generate-model config; overrides adjust fields.
        kwargs = dict(
            name="test",
            model="some-model",
            usecase=ModelUsecase.generate,
            loader=ModelLoader.vllm,
        )
        kwargs.update(overrides)
        return YashaModelConfig(**kwargs)

    def test_default_num_replicas(self):
        assert self._make_config().num_replicas == 1

    def test_custom_num_replicas(self):
        assert self._make_config(num_replicas=3).num_replicas == 3
1 change: 1 addition & 0 deletions yasha/infer/infer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class YashaModelConfig(BaseModel):
plugin: str | None = None # only meaningful for loader='custom', silently ignored otherwise
num_gpus: float = 0
num_cpus: float = 0.1
num_replicas: int = 1
vllm_engine_kwargs: VllmEngineConfig = Field(default_factory=VllmEngineConfig)
transformers_config: TransformersConfig | None = None
diffusers_config: DiffusersConfig | None = None
Expand Down
10 changes: 7 additions & 3 deletions yasha/openai/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def _error_response(result: ErrorResponse) -> JSONResponse:
@serve.deployment
@serve.ingress(app)
class YashaAPI:
def __init__(self, model_handles: dict[str, tuple[DeploymentHandle, ModelUsecase]]):
self.models = {name: handle for name, (handle, _) in model_handles.items()}
def __init__(self, model_handles: dict[str, tuple[list[DeploymentHandle], ModelUsecase]]):
self.models: dict[str, list[DeploymentHandle]] = {name: handles for name, (handles, _) in model_handles.items()}
self._counters: dict[str, int] = {name: 0 for name in self.models}
self.model_list = [OpenAiModelCard(id=name) for name in model_handles]
MODELS_LOADED.set(len(self.models)) # all models are RUNNING by this point

Expand All @@ -130,7 +131,10 @@ def _set_request_id(request_id: str) -> None:
def _get_handle(self, model_name: str | None) -> DeploymentHandle:
if model_name is None or model_name not in self.models:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND.value, detail="model not found")
return self.models[model_name]
handles = self.models[model_name]
idx = self._counters[model_name] % len(handles)
self._counters[model_name] += 1
return handles[idx]

async def _handle_response(
self,
Expand Down
Loading