From ff6cfb6ae124eda01835e1db9180d4bff9f765dd Mon Sep 17 00:00:00 2001 From: Alex M Date: Mon, 6 Apr 2026 04:03:18 +0000 Subject: [PATCH 1/2] feat: add diffusers loader for OpenAI-compatible image generation Add a new `diffusers` loader for plug-and-play image generation models via HuggingFace Diffusers' AutoPipelineForText2Image. Exposes a new POST /v1/images/generations endpoint returning base64-encoded PNGs. - Add `image` usecase and `diffusers` loader to config enums - Add DiffusersConfig (torch_dtype, num_inference_steps, guidance_scale) - Add DiffusersInfer backend following TransformersInfer pattern - Add OpenAIServingImage serving layer with run_in_executor for GPU calls - Add ImageGenerationRequest/Response protocol models - Add create_image_generation stubs to vllm/transformers/custom backends - Make model and loader mandatory fields in YashaModelConfig - Return all models in /v1/models regardless of usecase - Add accelerate and diffusers dependencies - Update docs and README --- README.md | 5 +- docs/architecture.md | 2 + docs/model-configuration.md | 31 ++++- pyproject.toml | 6 +- uv.lock | 42 +++++++ yasha/infer/custom/custom_infer.py | 8 ++ yasha/infer/diffusers/diffusers_infer.py | 108 ++++++++++++++++++ yasha/infer/diffusers/openai/serving_image.py | 79 +++++++++++++ yasha/infer/infer_config.py | 21 ++-- yasha/infer/model_deployment.py | 13 +++ .../infer/transformers/transformers_infer.py | 8 ++ yasha/infer/vllm/vllm_infer.py | 8 ++ yasha/openai/api.py | 23 +++- yasha/openai/protocol.py | 27 +++++ 14 files changed, 361 insertions(+), 20 deletions(-) create mode 100644 yasha/infer/diffusers/diffusers_infer.py create mode 100644 yasha/infer/diffusers/openai/serving_image.py diff --git a/README.md b/README.md index 70b65d2..1c712e4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Yasha -Self-hosted, multi-model AI inference server. Runs LLMs alongside specialized models (TTS, speech-to-text, embeddings) on one or more GPUs, exposing an OpenAI-compatible API. Built on [vLLM](https://github.com/vllm-project/vllm) and [Ray](https://github.com/ray-project/ray). +Self-hosted, multi-model AI inference server. Runs LLMs alongside specialized models (TTS, speech-to-text, embeddings, image generation) on one or more GPUs, exposing an OpenAI-compatible API. Built on [vLLM](https://github.com/vllm-project/vllm) and [Ray](https://github.com/ray-project/ray). 
## Architecture @@ -36,7 +36,7 @@ Each model runs as an isolated Ray Serve deployment with its own lifecycle, heal ## Features -- **Multi-model on a single GPU** — run chat, embedding, STT, and TTS models simultaneously with tunable per-model GPU memory allocation +- **Multi-model on a single GPU** — run chat, embedding, STT, TTS, and image generation models simultaneously with tunable per-model GPU memory allocation - **Per-model isolated deployments** — each model runs in its own Ray Serve deployment with independent lifecycle, health checks, and failure isolation - **OpenAI-compatible API** — drop-in replacement for any OpenAI SDK client - **Streaming** — SSE streaming for chat completions and TTS audio @@ -55,6 +55,7 @@ Each model runs as an isolated Ray Serve deployment with its own lifecycle, heal | `POST /v1/audio/transcriptions` | Speech-to-text | | `POST /v1/audio/translations` | Audio translation | | `POST /v1/audio/speech` | Text-to-speech (SSE streaming or single-response) | +| `POST /v1/images/generations` | Image generation | | `GET /v1/models` | List available models | ## Quick Start diff --git a/docs/architecture.md b/docs/architecture.md index 916390d..96a2c80 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -34,6 +34,7 @@ Each deployment uses one of three loaders: |--------|---------|-----------| | `vllm` | vLLM engine | Chat/generation, embeddings, transcription, translation | | `transformers` | PyTorch + HuggingFace | Custom model implementations | +| `diffusers` | HuggingFace Diffusers | Image generation (any `AutoPipelineForText2Image` model) | | `custom` | Plugin system | TTS backends (Kokoro, Bark, Orpheus) | ## GPU Allocation @@ -66,5 +67,6 @@ See [Plugin Development](plugins.md) for details. | `yasha/infer/model_deployment.py` | Ray Serve deployment actor | | `yasha/infer/infer_config.py` | Pydantic config models and protocols | | `yasha/infer/vllm/vllm_infer.py` | vLLM engine wrapper | +| `yasha/infer/diffusers/diffusers_infer.py` | Diffusers pipeline wrapper | | `yasha/plugins/base_plugin.py` | Plugin base classes | | `config/models.yaml` | Model configuration | diff --git a/docs/model-configuration.md b/docs/model-configuration.md index 2f69266..09fafd3 100644 --- a/docs/model-configuration.md +++ b/docs/model-configuration.md @@ -8,13 +8,14 @@ Models are configured in `config/models.yaml`. 
Each entry defines one deployment |---|---|---| | `name` | string | Model identifier used in API requests | | `model` | string | HuggingFace model ID | -| `usecase` | string | `generate`, `embed`, `transcription`, `translation`, or `tts` | -| `loader` | string | `vllm`, `transformers`, or `custom` | +| `usecase` | string | `generate`, `embed`, `transcription`, `translation`, `tts`, or `image` | +| `loader` | string | `vllm`, `transformers`, `diffusers`, or `custom` | | `plugin` | string | Plugin module name (required when `loader: custom`); must be installed via `uv sync --extra ` | | `num_gpus` | float | Fraction of a GPU to allocate (0.0–1.0); also sets vLLM `gpu_memory_utilization` | | `num_cpus` | float | CPU units to allocate (default `0.1`) | | `use_gpu` | int \| string | Pin to a specific GPU (see below) | | `vllm_engine_kwargs` | object | Passed directly to the vLLM engine — see [vLLM engine args](https://docs.vllm.ai/en/latest/configuration/engine_args.html) | +| `diffusers_config` | object | Diffusers pipeline options (see below) | | `plugin_config` | object | Plugin-specific options passed through to the plugin | ## GPU Pinning @@ -29,6 +30,32 @@ Models are configured in `config/models.yaml`. Each entry defines one deployment The name is arbitrary — it must match the value in `use_gpu`. The `models.example.2x16GB.yaml` preset uses `"dual_16gb"` for a TP=2 LLM deployment. - **omit** — Ray schedules the deployment freely across available GPUs +## Diffusers Config + +Options for `loader: diffusers` models (image generation via HuggingFace Diffusers): + +| Field | Type | Default | Description | +|---|---|---|---| +| `torch_dtype` | string | `float16` | Torch dtype (`float16`, `bfloat16`, `float32`) | +| `num_inference_steps` | int | `30` | Default denoising steps (can be overridden per request) | +| `guidance_scale` | float | `7.5` | Default classifier-free guidance scale (can be overridden per request) | + +Any model supported by `AutoPipelineForText2Image` works out of the box — Stable Diffusion 1.5/2.x/XL/3.x, SDXL Turbo, Flux, PixArt, Kandinsky, etc. + +Example: + +```yaml +- name: "sdxl-turbo" + model: "stabilityai/sdxl-turbo" + usecase: "image" + loader: "diffusers" + num_gpus: 0.35 + diffusers_config: + torch_dtype: "float16" + num_inference_steps: 4 + guidance_scale: 0.0 +``` + ## Environment Variables | Variable | Description | Default | diff --git a/pyproject.toml b/pyproject.toml index 71a5f77..1c378b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [project] name = "yasha" version = "0.1.14" -description = "Self-hosted, multi-model AI inference server. Run LLMs, TTS, STT, and embeddings with an OpenAI-compatible API." +description = "Self-hosted, multi-model AI inference server. Run LLMs, TTS, STT, embeddings, and image generation with an OpenAI-compatible API." 
authors = [ { name = "Alex Margarit" } ] readme = { file = "README.md", content-type = "text/markdown" } license = { text = "MIT" } -keywords = ["ai", "inference", "vllm", "ray", "openai", "llm", "tts", "stt", "embeddings", "self-hosted"] +keywords = ["ai", "inference", "vllm", "ray", "openai", "llm", "tts", "stt", "embeddings", "image-generation", "diffusers", "self-hosted"] classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: MIT License", @@ -34,6 +34,8 @@ dependencies = [ "vllm==0.18.0", "vllm[audio]==0.18.0", "requests>=2.32.5", + "accelerate>=1.6.0", + "diffusers>=0.31.0", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 1b5a934..f1c6442 100644 --- a/uv.lock +++ b/uv.lock @@ -15,6 +15,24 @@ members = [ "yasha", ] +[[package]] +name = "accelerate" +version = "1.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyyaml" }, + { name = "safetensors" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" }, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -536,6 +554,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cf/65/4df6936130b56e1429114e663e7c1576cf845f3aef1b2dd200c0a5d19dba/depyf-0.20.0-py3-none-any.whl", hash = "sha256:d31effad4261cebecb58955d832e448ace88f432328f95f82fd99c30fd9308d4", size = 39381, upload-time = "2025-10-13T12:33:33.647Z" }, ] +[[package]] +name = "diffusers" +version = "0.37.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "httpx" }, + { name = "huggingface-hub" }, + { name = "importlib-metadata" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "regex" }, + { name = "requests" }, + { name = "safetensors" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/46/5c/f4c2eb8d481fe8784a7e2331fbaab820079c06676185fa6d2177b386d590/diffusers-0.37.1.tar.gz", hash = "sha256:2346c21f77f835f273b7aacbaada1c34a596a3a2cc6ddc99d149efcd0ec298fa", size = 4135139, upload-time = "2026-03-25T08:04:04.515Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/dd/51c38785ce5e1c287b5ad17ba550edaaaffce0deb0da4857019c6700fbaf/diffusers-0.37.1-py3-none-any.whl", hash = "sha256:0537c0b28cb53cf39d6195489bcf8f833986df556c10f5e28ab7427b86fc8b90", size = 5001536, upload-time = "2026-03-25T08:04:02.385Z" }, +] + [[package]] name = "dill" version = "0.4.1" @@ -3675,8 +3713,10 @@ name = "yasha" version = "0.1.14" source = { editable = "." 
} dependencies = [ + { name = "accelerate" }, { name = "argparse" }, { name = "asyncio" }, + { name = "diffusers" }, { name = "fastapi" }, { name = "flashinfer-python" }, { name = "httpx" }, @@ -3713,9 +3753,11 @@ orpheus = [ [package.metadata] requires-dist = [ + { name = "accelerate", specifier = ">=1.6.0" }, { name = "argparse", specifier = ">=1.4.0" }, { name = "asyncio", specifier = ">=4.0.0" }, { name = "bark", marker = "extra == 'bark'", editable = "plugins/bark" }, + { name = "diffusers", specifier = ">=0.31.0" }, { name = "fastapi", specifier = ">=0.116.1" }, { name = "flashinfer-python", specifier = ">=0.6.1" }, { name = "httpx", specifier = ">=0.28.1" }, diff --git a/yasha/infer/custom/custom_infer.py b/yasha/infer/custom/custom_infer.py index 025619a..d37db5f 100644 --- a/yasha/infer/custom/custom_infer.py +++ b/yasha/infer/custom/custom_infer.py @@ -12,6 +12,7 @@ EmbeddingRequest, ErrorInfo, ErrorResponse, + ImageGenerationRequest, RawSpeechResponse, SpeechRequest, TranscriptionRequest, @@ -79,3 +80,10 @@ async def create_speech( error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404) ) return await self.serving_speech.create_speech(request, cast("Request", raw_request)) + + async def create_image_generation( + self, _request: ImageGenerationRequest, _raw_request: DisconnectProxy + ) -> ErrorResponse: + return ErrorResponse( + error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404) + ) diff --git a/yasha/infer/diffusers/diffusers_infer.py b/yasha/infer/diffusers/diffusers_infer.py new file mode 100644 index 0000000..eb1902b --- /dev/null +++ b/yasha/infer/diffusers/diffusers_infer.py @@ -0,0 +1,108 @@ +import logging +from typing import cast + +import torch +from starlette.requests import Request + +from yasha.infer.diffusers.openai.serving_image import OpenAIServingImage +from yasha.infer.infer_config import DiffusersConfig, DisconnectProxy, ModelUsecase, YashaModelConfig +from yasha.openai.protocol import ( + ChatCompletionRequest, + EmbeddingRequest, + ErrorInfo, + ErrorResponse, + ImageGenerationRequest, + ImageGenerationResponse, + SpeechRequest, + TranscriptionRequest, + TranslationRequest, +) + +logger = logging.getLogger("ray") + +_TORCH_DTYPES = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +class DiffusersInfer: + def __init__(self, model_config: YashaModelConfig): + self.model_config = model_config + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + + if torch.cuda.is_available() and model_config.num_gpus < 1.0: + torch.cuda.set_per_process_memory_fraction(model_config.num_gpus) + + def __del__(self): + try: + if pipeline := getattr(self, "_pipeline", None): + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception: + pass + + async def start(self): + from diffusers import AutoPipelineForText2Image + + config = self.model_config.diffusers_config or DiffusersConfig() + dtype = _TORCH_DTYPES.get(config.torch_dtype, torch.float16) + + logger.info( + "Loading diffusers pipeline: %s (dtype=%s, device=%s)", + self.model_config.model, + config.torch_dtype, + self.device, + ) + self._pipeline = AutoPipelineForText2Image.from_pretrained( + self.model_config.model, + torch_dtype=dtype, + ).to(self.device) + + self.serving_image: OpenAIServingImage | None = ( + OpenAIServingImage(pipeline=self._pipeline, config=config) + if self.model_config.usecase is ModelUsecase.image + else None + ) + 
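+    # The non-image endpoints below are stubs returning 404-style errors,
+    # mirroring the TransformersInfer pattern, so the deployment actor can
+    # dispatch any create_* call without special-casing the loader.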
+    async def create_chat_completion(
+        self, _request: ChatCompletionRequest, _raw_request: DisconnectProxy
+    ) -> ErrorResponse:
+        return ErrorResponse(
+            error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
+        )
+
+    async def create_embedding(self, _request: EmbeddingRequest, _raw_request: DisconnectProxy) -> ErrorResponse:
+        return ErrorResponse(
+            error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
+        )
+
+    async def create_transcription(
+        self, _audio_data: bytes, _request: TranscriptionRequest, _raw_request: DisconnectProxy
+    ) -> ErrorResponse:
+        return ErrorResponse(
+            error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
+        )
+
+    async def create_translation(
+        self, _audio_data: bytes, _request: TranslationRequest, _raw_request: DisconnectProxy
+    ) -> ErrorResponse:
+        return ErrorResponse(
+            error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
+        )
+
+    async def create_speech(self, _request: SpeechRequest, _raw_request: DisconnectProxy) -> ErrorResponse:
+        return ErrorResponse(
+            error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
+        )
+
+    async def create_image_generation(
+        self, request: ImageGenerationRequest, raw_request: DisconnectProxy
+    ) -> ErrorResponse | ImageGenerationResponse:
+        if self.serving_image is None:
+            return ErrorResponse(
+                error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
+            )
+        return await self.serving_image.create_image_generation(request, cast("Request", raw_request))
diff --git a/yasha/infer/diffusers/openai/serving_image.py b/yasha/infer/diffusers/openai/serving_image.py
new file mode 100644
index 0000000..e339ca1
--- /dev/null
+++ b/yasha/infer/diffusers/openai/serving_image.py
@@ -0,0 +1,79 @@
+import asyncio
+import base64
+import io
+import logging
+import time
+
+from diffusers import AutoPipelineForText2Image
+from fastapi import Request
+
+from yasha.infer.infer_config import DiffusersConfig
+from yasha.openai.protocol import (
+    ImageGenerationRequest,
+    ImageGenerationResponse,
+    ImageObject,
+    create_error_response,
+)
+from yasha.utils import base_request_id
+
+logger = logging.getLogger("ray")
+
+
+class OpenAIServingImage:
+    request_id_prefix = "img"
+
+    def __init__(self, pipeline: AutoPipelineForText2Image, config: DiffusersConfig):
+        self.pipeline = pipeline
+        self.config = config
+
+    async def create_image_generation(
+        self, request: ImageGenerationRequest, raw_request: Request
+    ) -> ImageGenerationResponse:
+        request_id = f"{self.request_id_prefix}-{base_request_id(raw_request)}"
+        logger.info(
+            "image generation request %s: prompt=%r, n=%d, size=%s", request_id, request.prompt, request.n, request.size
+        )
+
+        try:
+            width, height = _parse_size(request.size)
+        except ValueError as e:
+            return create_error_response(str(e))
+
+        steps = request.num_inference_steps if request.num_inference_steps is not None else self.config.num_inference_steps
+        guidance = request.guidance_scale if request.guidance_scale is not None else self.config.guidance_scale
+
+        loop = asyncio.get_running_loop()
+        images = await loop.run_in_executor(
+            None,
+            lambda: (
+                self.pipeline(
+                    prompt=request.prompt,
+                    num_images_per_prompt=request.n,
+                    width=width,
+                    height=height,
+                    num_inference_steps=steps,
+                    guidance_scale=guidance,
+                ).images
+            ),
+        )
+
+        data = []
+        for img in images:
+            buf = io.BytesIO()
+            img.save(buf, format="PNG")
+            b64 = 
base64.b64encode(buf.getvalue()).decode("utf-8") + data.append(ImageObject(b64_json=b64, revised_prompt=request.prompt)) + + return ImageGenerationResponse(created=int(time.time()), data=data) + + +def _parse_size(size: str) -> tuple[int, int]: + parts = size.lower().split("x") + if len(parts) != 2: + raise ValueError(f"Invalid size format '{size}', expected WxH (e.g. '512x512')") + w, h = int(parts[0]), int(parts[1]) + if w <= 0 or h <= 0: + raise ValueError(f"Width and height must be positive, got {w}x{h}") + if w % 8 != 0 or h % 8 != 0: + raise ValueError(f"Width and height must be multiples of 8, got {w}x{h}") + return w, h diff --git a/yasha/infer/infer_config.py b/yasha/infer/infer_config.py index 44610be..dc5b90e 100644 --- a/yasha/infer/infer_config.py +++ b/yasha/infer/infer_config.py @@ -15,11 +15,13 @@ class ModelUsecase(StrEnum): transcription = "transcription" translation = "translation" tts = "tts" + image = "image" class ModelLoader(StrEnum): vllm = "vllm" transformers = "transformers" + diffusers = "diffusers" custom = "custom" @@ -47,27 +49,26 @@ class TransformersConfig(BaseModel): device: str = "cpu" +class DiffusersConfig(BaseModel): + torch_dtype: str = "float16" + num_inference_steps: int = 30 + guidance_scale: float = 7.5 + + class YashaModelConfig(BaseModel): name: str - model: str | None = None + model: str usecase: ModelUsecase - loader: ModelLoader = ModelLoader.vllm + loader: ModelLoader plugin: str | None = None # only meaningful for loader='custom', silently ignored otherwise num_gpus: float = 0 num_cpus: float = 0.1 use_gpu: int | str | None = None vllm_engine_kwargs: VllmEngineConfig = Field(default_factory=VllmEngineConfig) transformers_config: TransformersConfig | None = None + diffusers_config: DiffusersConfig | None = None plugin_config: dict[str, Any] | None = None # plugin devs parse this themselves - @model_validator(mode="after") - def check_model_or_plugin(self): - if self.model is None and self.plugin is None: - raise ValueError("model and plugin fields cannot be both empty") - if self.loader in (ModelLoader.vllm, ModelLoader.transformers) and self.model is None: - raise ValueError(f"loader='{self.loader}' requires model to be set") - return self - @model_validator(mode="after") def check_custom_requires_plugin(self): if self.loader == ModelLoader.custom and self.plugin is None: diff --git a/yasha/infer/model_deployment.py b/yasha/infer/model_deployment.py index 21f8888..23132a3 100644 --- a/yasha/infer/model_deployment.py +++ b/yasha/infer/model_deployment.py @@ -5,12 +5,14 @@ from ray import serve from yasha.infer.custom.custom_infer import CustomInfer +from yasha.infer.diffusers.diffusers_infer import DiffusersInfer from yasha.infer.infer_config import DisconnectProxy, ModelLoader, YashaModelConfig from yasha.infer.transformers.transformers_infer import TransformersInfer from yasha.infer.vllm.vllm_infer import VllmInfer from yasha.openai.protocol import ( ChatCompletionRequest, EmbeddingRequest, + ImageGenerationRequest, SpeechRequest, TranscriptionRequest, TranslationRequest, @@ -27,6 +29,8 @@ async def __init__(self, config: YashaModelConfig): self.infer = VllmInfer(config) elif config.loader == ModelLoader.transformers: self.infer = TransformersInfer(config) + elif config.loader == ModelLoader.diffusers: + self.infer = DiffusersInfer(config) else: self.infer = CustomInfer(config) @@ -80,3 +84,12 @@ async def speak(self, request: SpeechRequest, request_headers: dict[str, str], d yield chunk else: yield result + + async def imagine(self, 
request: ImageGenerationRequest, request_headers: dict[str, str], disconnect_event: Any): + proxy = DisconnectProxy(disconnect_event, request_headers) + result = await self.infer.create_image_generation(request, proxy) + if isinstance(result, AsyncGenerator): + async for chunk in result: + yield chunk + else: + yield result diff --git a/yasha/infer/transformers/transformers_infer.py b/yasha/infer/transformers/transformers_infer.py index f92da95..d48a54c 100644 --- a/yasha/infer/transformers/transformers_infer.py +++ b/yasha/infer/transformers/transformers_infer.py @@ -13,6 +13,7 @@ EmbeddingRequest, ErrorInfo, ErrorResponse, + ImageGenerationRequest, RawSpeechResponse, SpeechRequest, TranscriptionRequest, @@ -96,3 +97,10 @@ async def create_speech( error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404) ) return await self.serving_speech.create_speech(request, cast("Request", raw_request)) + + async def create_image_generation( + self, _request: ImageGenerationRequest, _raw_request: DisconnectProxy + ) -> ErrorResponse: + return ErrorResponse( + error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404) + ) diff --git a/yasha/infer/vllm/vllm_infer.py b/yasha/infer/vllm/vllm_infer.py index 0248453..9816d77 100644 --- a/yasha/infer/vllm/vllm_infer.py +++ b/yasha/infer/vllm/vllm_infer.py @@ -25,6 +25,7 @@ EmbeddingRequest, ErrorInfo, ErrorResponse, + ImageGenerationRequest, RawSpeechResponse, SpeechRequest, TranscriptionRequest, @@ -286,3 +287,10 @@ async def create_speech( error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404) ) return await self.serving_speech.create_speech(request, cast("Request", raw_request)) + + async def create_image_generation( + self, _request: ImageGenerationRequest, _raw_request: DisconnectProxy + ) -> ErrorResponse: + return ErrorResponse( + error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404) + ) diff --git a/yasha/openai/api.py b/yasha/openai/api.py index a7a90ef..4227147 100644 --- a/yasha/openai/api.py +++ b/yasha/openai/api.py @@ -17,6 +17,8 @@ EmbeddingRequest, EmbeddingResponse, ErrorResponse, + ImageGenerationRequest, + ImageGenerationResponse, RawSpeechResponse, SpeechRequest, TranscriptionRequest, @@ -75,9 +77,7 @@ def _error_response(result: ErrorResponse) -> JSONResponse: class YashaAPI: def __init__(self, model_handles: dict[str, tuple[DeploymentHandle, ModelUsecase]]): self.models = {name: handle for name, (handle, _) in model_handles.items()} - self.model_list = [ - OpenAiModelCard(id=name) for name, (_, usecase) in model_handles.items() if usecase is ModelUsecase.generate - ] + self.model_list = [OpenAiModelCard(id=name) for name in model_handles] def _get_handle(self, model_name: str | None) -> DeploymentHandle: if model_name is None or model_name not in self.models: @@ -97,7 +97,14 @@ async def _handle_response( watcher.stop() return Response(content=first.audio, media_type=first.media_type) - if isinstance(first, ChatCompletionResponse | EmbeddingResponse | TranscriptionResponse | TranslationResponse): + if isinstance( + first, + ChatCompletionResponse + | EmbeddingResponse + | TranscriptionResponse + | TranslationResponse + | ImageGenerationResponse, + ): watcher.stop() return JSONResponse(content=first.model_dump(mode="json")) @@ -192,3 +199,11 @@ async def create_speech(self, request: SpeechRequest, raw_request: Request): headers = dict(raw_request.headers) 
response_gen = handle.speak.options(stream=True).remote(request, headers, watcher.event) return await self._handle_response(response_gen, watcher) + + @app.post("/v1/images/generations") + async def create_image(self, request: ImageGenerationRequest, raw_request: Request): + handle = self._get_handle(request.model) + watcher = RequestWatcher(raw_request) + headers = dict(raw_request.headers) + response_gen = handle.imagine.options(stream=True).remote(request, headers, watcher.event) + return await self._handle_response(response_gen, watcher) diff --git a/yasha/openai/protocol.py b/yasha/openai/protocol.py index e043e11..d472642 100644 --- a/yasha/openai/protocol.py +++ b/yasha/openai/protocol.py @@ -85,6 +85,30 @@ class RawSpeechResponse(BaseModel): media_type: Literal["audio/wav"] = Field(default="audio/wav", description="audio bytes media type") +# -- image generation (not yet provided by vllm) --------------------------- +class ImageGenerationRequest(OpenAIBaseModel): + model: str = Field(..., description="The model to use for image generation.") + prompt: str = Field(..., description="A text description of the desired image(s).") + n: int = Field(default=1, ge=1, le=10, description="The number of images to generate.") + size: str = Field(default="512x512", description="The size of the generated images in WxH format.") + response_format: Literal["b64_json"] = Field( + default="b64_json", + description="The format in which the generated images are returned.", + ) + num_inference_steps: int | None = Field(default=None, description="Override default inference steps.") + guidance_scale: float | None = Field(default=None, description="Override default guidance scale.") + + +class ImageObject(OpenAIBaseModel): + b64_json: str = Field(..., description="The base64-encoded JSON of the generated image.") + revised_prompt: str | None = Field(default=None, description="The prompt that was used to generate the image.") + + +class ImageGenerationResponse(OpenAIBaseModel): + created: int = Field(..., description="The Unix timestamp of when the response was created.") + data: list[ImageObject] = Field(..., description="The list of generated images.") + + __all__ = [ "ChatCompletionRequest", "ChatCompletionResponse", @@ -92,6 +116,9 @@ class RawSpeechResponse(BaseModel): "EmbeddingResponse", "ErrorInfo", "ErrorResponse", + "ImageGenerationRequest", + "ImageGenerationResponse", + "ImageObject", "OpenAIBaseModel", "RawSpeechResponse", "SpeechRequest", From 938e81001c6bfacb079517a0320bc42951906f83 Mon Sep 17 00:00:00 2001 From: Alex M Date: Mon, 6 Apr 2026 04:11:36 +0000 Subject: [PATCH 2/2] fix: update tests for mandatory model/loader fields and resolve pyright errors --- tests/test_config.py | 32 +++++++++---------- yasha/infer/diffusers/diffusers_infer.py | 2 +- yasha/infer/diffusers/openai/serving_image.py | 7 ++-- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index c6eae56..f236e6f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -18,6 +18,7 @@ def test_minimal_vllm_model(self): name="test-llm", model="some-org/some-model", usecase=ModelUsecase.generate, + loader=ModelLoader.vllm, ) assert config.name == "test-llm" assert config.loader == ModelLoader.vllm @@ -46,35 +47,27 @@ def test_custom_loader_with_plugin(self): def test_custom_loader_plugin_only(self): config = YashaModelConfig( name="test-tts", + model="some-model", usecase=ModelUsecase.tts, loader=ModelLoader.custom, plugin="kokoro", ) - assert config.model is None 
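+        # `model` is mandatory for every loader; the plugin only selects the backend.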
assert config.plugin == "kokoro" - def test_vllm_loader_requires_model(self): - with pytest.raises(ValidationError, match="cannot be both empty"): + def test_model_required(self): + with pytest.raises(ValidationError, match="Field required"): YashaModelConfig( name="test-llm", usecase=ModelUsecase.generate, loader=ModelLoader.vllm, ) - def test_transformers_loader_requires_model(self): - with pytest.raises(ValidationError, match="cannot be both empty"): + def test_loader_required(self): + with pytest.raises(ValidationError, match="Field required"): YashaModelConfig( name="test-llm", + model="some-model", usecase=ModelUsecase.generate, - loader=ModelLoader.transformers, - ) - - def test_model_and_plugin_both_empty_fails(self): - with pytest.raises(ValidationError, match="cannot be both empty"): - YashaModelConfig( - name="test", - usecase=ModelUsecase.generate, - loader=ModelLoader.custom, ) def test_gpu_index_with_tensor_parallelism_fails(self): @@ -83,6 +76,7 @@ def test_gpu_index_with_tensor_parallelism_fails(self): name="test-llm", model="some-model", usecase=ModelUsecase.generate, + loader=ModelLoader.vllm, use_gpu=0, vllm_engine_kwargs=VllmEngineConfig(tensor_parallel_size=2), ) @@ -92,6 +86,7 @@ def test_gpu_index_with_tp1_ok(self): name="test-llm", model="some-model", usecase=ModelUsecase.generate, + loader=ModelLoader.vllm, use_gpu=0, vllm_engine_kwargs=VllmEngineConfig(tensor_parallel_size=1), ) @@ -102,6 +97,7 @@ def test_named_gpu_resource_with_tp(self): name="test-llm", model="some-model", usecase=ModelUsecase.generate, + loader=ModelLoader.vllm, use_gpu="dual_16gb", vllm_engine_kwargs=VllmEngineConfig(tensor_parallel_size=2), ) @@ -112,6 +108,7 @@ def test_gpu_allocation_fraction(self): name="test-llm", model="some-model", usecase=ModelUsecase.generate, + loader=ModelLoader.vllm, num_gpus=0.70, ) assert config.num_gpus == 0.70 @@ -122,16 +119,15 @@ def test_all_usecases_valid(self): name=f"test-{usecase.value}", model="some-model", usecase=usecase, + loader=ModelLoader.vllm, ) assert config.usecase == usecase def test_all_loaders_valid(self): for loader in ModelLoader: - kwargs = {"name": "test", "usecase": ModelUsecase.generate} + kwargs = {"name": "test", "model": "some-model", "usecase": ModelUsecase.generate} if loader == ModelLoader.custom: kwargs["plugin"] = "test-plugin" - else: - kwargs["model"] = "some-model" config = YashaModelConfig(loader=loader, **kwargs) assert config.loader == loader @@ -165,10 +161,12 @@ def test_multi_model_config(self): name="llm", model="some-org/some-llm", usecase=ModelUsecase.generate, + loader=ModelLoader.vllm, num_gpus=0.70, ), YashaModelConfig( name="tts", + model="some-model", usecase=ModelUsecase.tts, loader=ModelLoader.custom, plugin="kokoro", diff --git a/yasha/infer/diffusers/diffusers_infer.py b/yasha/infer/diffusers/diffusers_infer.py index eb1902b..c2ac396 100644 --- a/yasha/infer/diffusers/diffusers_infer.py +++ b/yasha/infer/diffusers/diffusers_infer.py @@ -45,7 +45,7 @@ def __del__(self): pass async def start(self): - from diffusers import AutoPipelineForText2Image + from diffusers.pipelines.auto_pipeline import AutoPipelineForText2Image config = self.model_config.diffusers_config or DiffusersConfig() dtype = _TORCH_DTYPES.get(config.torch_dtype, torch.float16) diff --git a/yasha/infer/diffusers/openai/serving_image.py b/yasha/infer/diffusers/openai/serving_image.py index e339ca1..a8670fb 100644 --- a/yasha/infer/diffusers/openai/serving_image.py +++ b/yasha/infer/diffusers/openai/serving_image.py @@ -4,11 +4,12 @@ import 
logging
 import time
 
-from diffusers import AutoPipelineForText2Image
+from diffusers.pipelines.auto_pipeline import AutoPipelineForText2Image
 from fastapi import Request
 
 from yasha.infer.infer_config import DiffusersConfig
 from yasha.openai.protocol import (
+    ErrorResponse,
     ImageGenerationRequest,
     ImageGenerationResponse,
     ImageObject,
@@ -28,7 +29,7 @@ def __init__(self, pipeline: AutoPipelineForText2Image, config: DiffusersConfig)
 
     async def create_image_generation(
         self, request: ImageGenerationRequest, raw_request: Request
-    ) -> ImageGenerationResponse:
+    ) -> ImageGenerationResponse | ErrorResponse:
        request_id = f"{self.request_id_prefix}-{base_request_id(raw_request)}"
         logger.info(
             "image generation request %s: prompt=%r, n=%d, size=%s", request_id, request.prompt, request.n, request.size
@@ -46,7 +47,7 @@
         images = await loop.run_in_executor(
             None,
             lambda: (
-                self.pipeline(
+                self.pipeline(  # pyright: ignore[reportCallIssue]
                     prompt=request.prompt,
                     num_images_per_prompt=request.n,
                     width=width,