diff --git a/yasha/infer/base_infer.py b/yasha/infer/base_infer.py
index bd60242..c520e47 100644
--- a/yasha/infer/base_infer.py
+++ b/yasha/infer/base_infer.py
@@ -1,4 +1,5 @@
 import logging
+import struct
 from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator
 
@@ -27,6 +28,25 @@
     error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
 )
 
+# 44-byte WAV header + 2 bytes of silence (one 16-bit sample at 16 kHz mono)
+_MINIMAL_WAV_HEADER = struct.pack(
+    "<4sI4s4sIHHIIHH4sI",
+    b"RIFF",
+    36 + 2,
+    b"WAVE",  # RIFF chunk
+    b"fmt ",
+    16,
+    1,
+    1,
+    16000,
+    32000,
+    2,
+    16,  # fmt sub-chunk: PCM, mono, 16 kHz, 16-bit
+    b"data",
+    2,  # data sub-chunk: 2 bytes
+)
+MINIMAL_WAV = _MINIMAL_WAV_HEADER + b"\x00\x00"
+
 
 class BaseInfer(ABC):
     def __init__(self, model_config: YashaModelConfig):
@@ -46,6 +66,14 @@ def _set_max_context_length(self, length: int | None) -> None:
     @abstractmethod
     async def start(self) -> None: ...
 
+    async def warmup(self) -> None:
+        """Run a minimal inference pass to warm up the model (CUDA kernels, caches, etc.).
+
+        Subclasses should override this to send a tiny dummy request through
+        their actual inference path. The default is a no-op for loaders that
+        don't need warmup.
+        """
+
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: DisconnectProxy
     ) -> ErrorResponse | ChatCompletionResponse | AsyncGenerator[str, None]:
diff --git a/yasha/infer/custom/custom_infer.py b/yasha/infer/custom/custom_infer.py
index e15ea42..fa59a6b 100644
--- a/yasha/infer/custom/custom_infer.py
+++ b/yasha/infer/custom/custom_infer.py
@@ -34,6 +34,17 @@
     async def start(self):
         self.serving_speech = await self.init_serving_speech()
 
+    async def warmup(self) -> None:
+        if self.serving_speech is None:
+            return
+        logger.info("Warming up custom TTS model: %s", self.model_config.name)
+        request = SpeechRequest(model=self.model_config.name, input="warmup", voice="default")
+        result = await self.create_speech(request, DisconnectProxy(None, {}))
+        if isinstance(result, AsyncGenerator):
+            async for _ in result:
+                pass
+        logger.info("Warmup TTS done for %s", self.model_config.name)
+
     async def init_serving_speech(self) -> OpenAIServingSpeech | None:
         logger.info("init serving speech with model: %s", self.model_config.name)
         return (
diff --git a/yasha/infer/diffusers/diffusers_infer.py b/yasha/infer/diffusers/diffusers_infer.py
index 4eab810..1413fcb 100644
--- a/yasha/infer/diffusers/diffusers_infer.py
+++ b/yasha/infer/diffusers/diffusers_infer.py
@@ -69,6 +69,21 @@ async def start(self):
             else None
         )
 
+    async def warmup(self) -> None:
+        if self.serving_image is None:
+            return
+        logger.info("Warming up diffusers model: %s", self.model_config.name)
+        request = ImageGenerationRequest(
+            model=self.model_config.name,
+            prompt="warmup",
+            n=1,
+            size="64x64",
+            num_inference_steps=1,
+            guidance_scale=0.0,
+        )
+        await self.create_image_generation(request, DisconnectProxy(None, {}))
+        logger.info("Warmup image generation done for %s", self.model_config.name)
+
     async def create_image_generation(
         self, request: ImageGenerationRequest, raw_request: DisconnectProxy
     ) -> ErrorResponse | ImageGenerationResponse:
diff --git a/yasha/infer/model_deployment.py b/yasha/infer/model_deployment.py
index d1dd8bb..03c3b6c 100644
--- a/yasha/infer/model_deployment.py
+++ b/yasha/infer/model_deployment.py
@@ -49,6 +49,7 @@ async def __init__(self, config: YashaModelConfig):
 
             self.infer = CustomInfer(config)
             await self.infer.start()
+            await self.infer.warmup()
         except Exception:
             MODEL_LOAD_FAILURES_TOTAL.inc(tags={"model": config.name, "loader": config.loader.value})
             raise
diff --git a/yasha/infer/transformers/transformers_infer.py b/yasha/infer/transformers/transformers_infer.py
index 34545c6..8c3944a 100644
--- a/yasha/infer/transformers/transformers_infer.py
+++ b/yasha/infer/transformers/transformers_infer.py
@@ -59,6 +59,17 @@ async def init_serving_speech(self) -> OpenAIServingSpeech | None:
 
         return OpenAIServingSpeech(speech_model=speech_model) if self.model_config.usecase is ModelUsecase.tts else None
 
+    async def warmup(self) -> None:
+        if self.serving_speech is None:
+            return
+        logger.info("Warming up transformers TTS model: %s", self.model_config.name)
+        request = SpeechRequest(model=self.model_config.name, input="warmup", voice="default")
+        result = await self.create_speech(request, DisconnectProxy(None, {}))
+        if isinstance(result, AsyncGenerator):
+            async for _ in result:
+                pass
+        logger.info("Warmup TTS done for %s", self.model_config.name)
+
     async def create_speech(
         self, request: SpeechRequest, raw_request: DisconnectProxy
     ) -> ErrorResponse | RawSpeechResponse | AsyncGenerator[str, None]:
diff --git a/yasha/infer/vllm/vllm_infer.py b/yasha/infer/vllm/vllm_infer.py
index 5085fc0..b826049 100644
--- a/yasha/infer/vllm/vllm_infer.py
+++ b/yasha/infer/vllm/vllm_infer.py
@@ -1,7 +1,9 @@
+import io
 import logging
 from collections.abc import AsyncGenerator
 from typing import ClassVar, cast
 
+from fastapi import UploadFile
 from starlette.requests import Request
 from starlette.responses import Response
 from vllm.config.model import ModelDType
@@ -17,12 +19,13 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.async_llm import AsyncLLM
 
-from yasha.infer.base_infer import BaseInfer
+from yasha.infer.base_infer import MINIMAL_WAV, BaseInfer
 from yasha.infer.infer_config import DisconnectProxy, ModelUsecase, VllmEngineConfig, YashaModelConfig
 from yasha.metrics import _ENABLED as _METRICS_ENABLED
 from yasha.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    EmbeddingCompletionRequest,
     EmbeddingRequest,
     ErrorResponse,
     TranscriptionRequest,
@@ -112,6 +115,50 @@
         self.serving_transcription = await self.init_serving_transcription()
         self.serving_translation = await self.init_serving_translation()
 
+    async def warmup(self) -> None:
+        logger.info("Warming up vllm model: %s", self.model_config.name)
+        dummy_proxy = DisconnectProxy(None, {})
+
+        if self.serving_chat is not None:
+            request = ChatCompletionRequest(
+                model=self.model_config.name, messages=[{"role": "user", "content": "warmup"}], max_tokens=1, seed=-1
+            )
+            result = await self.create_chat_completion(request, dummy_proxy)
+            if isinstance(result, AsyncGenerator):
+                async for _ in result:
+                    pass
+            logger.info("Warmup chat completion done for %s", self.model_config.name)
+
+        elif self.serving_embedding is not None:
+            request = EmbeddingCompletionRequest(
+                model=self.model_config.name,
+                input="warmup",
+            )
+            await self.create_embedding(request, dummy_proxy)
+            logger.info("Warmup embedding done for %s", self.model_config.name)
+
+        elif self.serving_transcription is not None:
+            request = TranscriptionRequest(
+                model=self.model_config.name, file=UploadFile(file=io.BytesIO(MINIMAL_WAV)), seed=-1
+            )
+            audio_data = MINIMAL_WAV
+            result = await self.create_transcription(audio_data, request, dummy_proxy)
+            if isinstance(result, AsyncGenerator):
+                async for _ in result:
+                    pass
logger.info("Warmup transcription done for %s", self.model_config.name) + + elif self.serving_translation is not None: + request = TranslationRequest( + model=self.model_config.name, file=UploadFile(file=io.BytesIO(MINIMAL_WAV)), seed=-1 + ) + audio_data = MINIMAL_WAV + result = await self.create_translation(audio_data, request, dummy_proxy) + if isinstance(result, AsyncGenerator): + async for _ in result: + pass + logger.info("Warmup translation done for %s", self.model_config.name) + async def init_serving_chat(self) -> OpenAIServingChat | None: logger.info("init_serving_chat: %s, %s", self.supported_tasks, self.model_config.usecase) if not (self.model_config.usecase is ModelUsecase.generate and "generate" in self.supported_tasks): diff --git a/yasha/openai/protocol.py b/yasha/openai/protocol.py index d472642..835840a 100644 --- a/yasha/openai/protocol.py +++ b/yasha/openai/protocol.py @@ -37,6 +37,7 @@ # -- embeddings ------------------------------------------------------------- from vllm.entrypoints.pooling.embed.protocol import ( + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, ) @@ -112,6 +113,7 @@ class ImageGenerationResponse(OpenAIBaseModel): __all__ = [ "ChatCompletionRequest", "ChatCompletionResponse", + "EmbeddingCompletionRequest", "EmbeddingRequest", "EmbeddingResponse", "ErrorInfo",