29 changes: 29 additions & 0 deletions yasha/infer/base_infer.py
@@ -1,4 +1,5 @@
import logging
import struct
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator

@@ -27,6 +28,25 @@
error=ErrorInfo(message="model does not support this action", type="invalid_request_error", code=404)
)

# 44-byte WAV header + 2 bytes of silence (one 16-bit sample at 16 kHz mono)
_MINIMAL_WAV_HEADER = struct.pack(
    "<4sI4s4sIHHIIHH4sI",
    b"RIFF",
    36 + 2,  # RIFF chunk size: 36 remaining header bytes + 2 data bytes
    b"WAVE",
    b"fmt ",
    16,  # fmt sub-chunk size
    1,  # audio format: PCM
    1,  # channels: mono
    16000,  # sample rate: 16 kHz
    32000,  # byte rate: 16000 samples/s * 2 bytes/sample
    2,  # block align: bytes per frame
    16,  # bits per sample
    b"data",
    2,  # data sub-chunk size: 2 bytes
)
MINIMAL_WAV = _MINIMAL_WAV_HEADER + b"\x00\x00"
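
A quick sanity check, using only Python's stdlib wave module, confirms the constant parses as exactly one frame of 16-bit mono PCM at 16 kHz (hence the byte rate of 16000 samples/s * 2 bytes/sample = 32000 above):

import io
import wave

with wave.open(io.BytesIO(MINIMAL_WAV)) as f:
    # Expect one mono 16-bit frame at 16 kHz: the two silence bytes.
    assert f.getnchannels() == 1
    assert f.getsampwidth() == 2
    assert f.getframerate() == 16000
    assert f.getnframes() == 1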


class BaseInfer(ABC):
def __init__(self, model_config: YashaModelConfig):
@@ -46,6 +66,15 @@ def _set_max_context_length(self, length: int | None) -> None:
@abstractmethod
async def start(self) -> None: ...

@abstractmethod
async def warmup(self) -> None:
"""Run a minimal inference pass to warm up the model (CUDA kernels, caches, etc.).

Subclasses must implement this by sending a tiny dummy request through
their actual inference path; a loader that needs no warmup should
implement this as a no-op.
"""

async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: DisconnectProxy
) -> ErrorResponse | ChatCompletionResponse | AsyncGenerator[str, None]:
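Because warmup is abstract, every loader must provide an implementation. For a loader with nothing to prime, the override can simply return; a minimal sketch (NoopInfer is a hypothetical name, not part of this diff):

class NoopInfer(BaseInfer):
    async def start(self) -> None: ...

    async def warmup(self) -> None:
        # Nothing to prime for this loader; the empty override just
        # satisfies the abstract contract.
        return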
11 changes: 11 additions & 0 deletions yasha/infer/custom/custom_infer.py
@@ -34,6 +34,17 @@ async def start(self):

self.serving_speech = await self.init_serving_speech()

async def warmup(self) -> None:
if self.serving_speech is None:
return
logger.info("Warming up custom TTS model: %s", self.model_config.name)
request = SpeechRequest(model=self.model_config.name, input="warmup", voice="default")
result = await self.create_speech(request, DisconnectProxy(None, {}))
if isinstance(result, AsyncGenerator):
async for _ in result:
pass
logger.info("Warmup TTS done for %s", self.model_config.name)

async def init_serving_speech(self) -> OpenAIServingSpeech | None:
logger.info("init serving speech with model: %s", self.model_config.name)
return (
15 changes: 15 additions & 0 deletions yasha/infer/diffusers/diffusers_infer.py
@@ -69,6 +69,21 @@ async def start(self):
else None
)

async def warmup(self) -> None:
if self.serving_image is None:
return
logger.info("Warming up diffusers model: %s", self.model_config.name)
request = ImageGenerationRequest(
model=self.model_config.name,
prompt="warmup",
n=1,
size="64x64",
num_inference_steps=1,
guidance_scale=0.0,
)
await self.create_image_generation(request, DisconnectProxy(None, {}))
logger.info("Warmup image generation done for %s", self.model_config.name)

async def create_image_generation(
self, request: ImageGenerationRequest, raw_request: DisconnectProxy
) -> ErrorResponse | ImageGenerationResponse:
1 change: 1 addition & 0 deletions yasha/infer/model_deployment.py
@@ -49,6 +49,7 @@ async def __init__(self, config: YashaModelConfig):
self.infer = CustomInfer(config)

await self.infer.start()
await self.infer.warmup()
except Exception:
MODEL_LOAD_FAILURES_TOTAL.inc(tags={"model": config.name, "loader": config.loader.value})
raise
11 changes: 11 additions & 0 deletions yasha/infer/transformers/transformers_infer.py
@@ -59,6 +59,17 @@ async def init_serving_speech(self) -> OpenAIServingSpeech | None:

return OpenAIServingSpeech(speech_model=speech_model) if self.model_config.usecase is ModelUsecase.tts else None

async def warmup(self) -> None:
if self.serving_speech is None:
return
logger.info("Warming up transformers TTS model: %s", self.model_config.name)
request = SpeechRequest(model=self.model_config.name, input="warmup", voice="default")
result = await self.create_speech(request, DisconnectProxy(None, {}))
if isinstance(result, AsyncGenerator):
async for _ in result:
pass
logger.info("Warmup TTS done for %s", self.model_config.name)

async def create_speech(
self, request: SpeechRequest, raw_request: DisconnectProxy
) -> ErrorResponse | RawSpeechResponse | AsyncGenerator[str, None]:
49 changes: 48 additions & 1 deletion yasha/infer/vllm/vllm_infer.py
@@ -1,7 +1,9 @@
import io
import logging
from collections.abc import AsyncGenerator
from typing import ClassVar, cast

from fastapi import UploadFile
from starlette.requests import Request
from starlette.responses import Response
from vllm.config.model import ModelDType
@@ -17,12 +19,13 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM

from yasha.infer.base_infer import BaseInfer
from yasha.infer.base_infer import MINIMAL_WAV, BaseInfer
from yasha.infer.infer_config import DisconnectProxy, ModelUsecase, VllmEngineConfig, YashaModelConfig
from yasha.metrics import _ENABLED as _METRICS_ENABLED
from yasha.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingCompletionRequest,
EmbeddingRequest,
ErrorResponse,
TranscriptionRequest,
@@ -112,6 +115,50 @@ async def start(self):
self.serving_transcription = await self.init_serving_transcription()
self.serving_translation = await self.init_serving_translation()

async def warmup(self) -> None:
logger.info("Warming up vllm model: %s", self.model_config.name)
dummy_proxy = DisconnectProxy(None, {})

if self.serving_chat is not None:
request = ChatCompletionRequest(
model=self.model_config.name, messages=[{"role": "user", "content": "warmup"}], max_tokens=1, seed=-1
)
result = await self.create_chat_completion(request, dummy_proxy)
if isinstance(result, AsyncGenerator):
async for _ in result:
pass
logger.info("Warmup chat completion done for %s", self.model_config.name)

elif self.serving_embedding is not None:
request = EmbeddingCompletionRequest(
model=self.model_config.name,
input="warmup",
)
await self.create_embedding(request, dummy_proxy)
logger.info("Warmup embedding done for %s", self.model_config.name)

elif self.serving_transcription is not None:
request = TranscriptionRequest(
model=self.model_config.name, file=UploadFile(file=io.BytesIO(MINIMAL_WAV)), seed=-1
)
audio_data = MINIMAL_WAV
result = await self.create_transcription(audio_data, request, dummy_proxy)
if isinstance(result, AsyncGenerator):
async for _ in result:
pass
logger.info("Warmup transcription done for %s", self.model_config.name)

elif self.serving_translation is not None:
request = TranslationRequest(
model=self.model_config.name, file=UploadFile(file=io.BytesIO(MINIMAL_WAV)), seed=-1
)
audio_data = MINIMAL_WAV
result = await self.create_translation(audio_data, request, dummy_proxy)
if isinstance(result, AsyncGenerator):
async for _ in result:
pass
logger.info("Warmup translation done for %s", self.model_config.name)

async def init_serving_chat(self) -> OpenAIServingChat | None:
logger.info("init_serving_chat: %s, %s", self.supported_tasks, self.model_config.usecase)
if not (self.model_config.usecase is ModelUsecase.generate and "generate" in self.supported_tasks):
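All four warmup implementations repeat the same drain idiom: if the result is an AsyncGenerator, exhaust it so the full streaming path actually executes. A small shared helper could factor that out; this is a sketch only, and drain_if_stream is a hypothetical name that does not appear in the diff:

from collections.abc import AsyncGenerator
from typing import Any

async def drain_if_stream(result: Any) -> None:
    # Exhaust streaming responses so the whole inference path runs;
    # non-streaming results are already materialized and need no action.
    if isinstance(result, AsyncGenerator):
        async for _ in result:
            pass

Each warmup body would then collapse to a call such as: await drain_if_stream(await self.create_speech(request, dummy_proxy)).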
2 changes: 2 additions & 0 deletions yasha/openai/protocol.py
@@ -37,6 +37,7 @@

# -- embeddings -------------------------------------------------------------
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
@@ -112,6 +113,7 @@ class ImageGenerationResponse(OpenAIBaseModel):
__all__ = [
"ChatCompletionRequest",
"ChatCompletionResponse",
"EmbeddingCompletionRequest",
"EmbeddingRequest",
"EmbeddingResponse",
"ErrorInfo",