diff --git a/src/eva/assistant/pipeline/nvidia_stt.py b/src/eva/assistant/pipeline/nvidia_stt.py index 1e6b0974..50cc920b 100644 --- a/src/eva/assistant/pipeline/nvidia_stt.py +++ b/src/eva/assistant/pipeline/nvidia_stt.py @@ -1,31 +1,63 @@ """NVIDIA Parakeet streaming speech-to-text service implementation. -Follows the same pattern as Pipecat's built-in AssemblyAI STT service. -The subclass only handles server-specific protocol (connection, audio format, -message parsing). All VAD, TTFB metrics, and finalization are handled by the -WebsocketSTTService base class. +Audio gating strategy — bot-speaking gate: + - The audio gate is OPEN by default. Audio flows to Parakeet whenever the + bot is not speaking. + - The gate CLOSES on BotStartedSpeakingFrame. Any buffered transcript + parts are discarded so that stale Parakeet completions from the + inter-turn silence period do not bleed into the next user turn. + - The gate OPENS on BotStoppedSpeakingFrame, resuming normal audio flow. + - A keepalive sends silent audio during long bot-speech turns to prevent + the Parakeet WebSocket from closing. + +Finalization (VAD-primary, Parakeet-fallback): + - When VAD fires stop, finalize immediately or wait for the next + ``completed`` event (primary path). + - If Parakeet emits a non-empty ``completed`` and VAD has NOT fired, a + fallback timer starts. If VAD still hasn't fired when the timer + expires, we auto-finalize using Parakeet's transcript — this handles + the case where Silero VAD misses a short utterance. 
""" import asyncio +import base64 import json import ssl import time from collections.abc import AsyncGenerator +from urllib.parse import urlparse +import httpx import websockets -from loguru import logger from pipecat.frames.frames import ( + AudioRawFrame, + BotStartedSpeakingFrame, + BotStoppedSpeakingFrame, CancelFrame, EndFrame, Frame, InterimTranscriptionFrame, StartFrame, TranscriptionFrame, + VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame, ) from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.settings import STTSettings from pipecat.services.stt_service import WebsocketSTTService +from eva.utils.logging import get_logger + +logger = get_logger(__name__) + +# Seconds after VAD stop to wait for a `completed` before force-finalizing. +_FINALIZE_TIMEOUT_SECS = 3.0 + +# Seconds after a Parakeet `completed` (with no VAD) before auto-finalizing. +# Gives VAD a chance to catch up; if it doesn't, Parakeet's own sentence +# detection serves as the fallback signal. +_FALLBACK_FINALIZE_SECS = 1.5 + def current_time_ms(): return str(int(round(time.time() * 1000))) @@ -37,11 +69,11 @@ class NVidiaWebSocketSTTService(WebsocketSTTService): Provides real-time speech recognition using NVIDIA's Parakeet ASR model via WebSocket. 
- Server protocol: -    - Audio in: 16-bit PCM, 16kHz, mono (raw bytes) -    - Reset in: {"type": "reset", "finalize": true} (triggers final transcript) -    - Ready out: {"type": "ready"} -    - Transcript out: {"type": "transcript", "text": "...", "is_final": true/false} +    Server protocol (OpenAI Realtime API): +    - Audio in: {"type": "input_audio_buffer.append", "audio": "<base64>"} +    - Commit in: {"type": "input_audio_buffer.commit"} +    - Ready out: {"type": "conversation.created"} +    - Transcript out: {"type": "conversation.item.input_audio_transcription.completed", ...} """ def __init__( @@ -51,20 +83,37 @@ def __init__( api_key: str | None = None, sample_rate: int = 16000, verify: bool = True, + model: str | None = None, **kwargs, ): - super().__init__(sample_rate=sample_rate, **kwargs) + super().__init__( + sample_rate=sample_rate, + settings=STTSettings(model=None, language=None), + # Send a silent keepalive every 10s after 15s of no audio, so the + # Parakeet WebSocket doesn't close during long bot-speech turns. + keepalive_timeout=15.0, + keepalive_interval=10.0, + **kwargs, + ) self._url = url self._api_key = api_key self._verify = verify + self._asr_model = model self._websocket = None self._receive_task: asyncio.Task | None = None self._ready = False + # Gate starts OPEN — audio flows to Parakeet by default. + # Only closed while the bot is speaking.
+ self._audio_gate_open = True + self._finalize_requested = False + self._finalize_timeout_task: asyncio.Task | None = None + self._fallback_finalize_task: asyncio.Task | None = None + self._transcript_parts: list[str] = [] def can_generate_metrics(self) -> bool: return True - # -- Lifecycle (matches AssemblyAI pattern exactly) -- + # -- Lifecycle -- async def start(self, frame: StartFrame): await super().start(frame) @@ -78,29 +127,158 @@ async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._disconnect() - # -- Audio sending -- + # -- Audio processing -- + + _audio_chunk_count: int = 0 + + async def process_audio_frame(self, frame: AudioRawFrame, direction: FrameDirection): + """Override base class to only reset keepalive timer when actually sending. + + The base STTService.process_audio_frame unconditionally resets + ``_last_audio_time`` on every audio frame — including silence during + bot speech. This prevents the keepalive from ever firing, so the + Parakeet WebSocket dies during long bot turns. + + When the gate is closed (bot speaking) we skip the base-class call + entirely so the keepalive timer keeps ticking. + """ + if self._muted: + return + + if self._audio_gate_open: + # Gate open — let the base class update _last_audio_time + # and call run_stt normally. + await super().process_audio_frame(frame, direction) + # Gate closed (bot speaking) — don't touch _last_audio_time so the + # keepalive timer keeps ticking. Audio is intentionally discarded. 
+ + async def _send_audio(self, audio: bytes): + """Send a single audio chunk to Parakeet (append + commit).""" + try: + await self._websocket.send( + json.dumps({"type": "input_audio_buffer.append", "audio": base64.b64encode(audio).decode("ascii")}) + ) + await self._websocket.send(json.dumps({"type": "input_audio_buffer.commit"})) + self._audio_chunk_count += 1 + if self._audio_chunk_count % 50 == 1: + logger.debug(f"{self} sent audio chunk #{self._audio_chunk_count} ({len(audio)} bytes)") + except Exception as e: + logger.error(f"{self} failed to send audio: {e}") async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: - if self._websocket and self._ready: - try: - await self._websocket.send(audio) - except Exception as e: - logger.error(f"{self} failed to send audio: {e}") + if not self._websocket or not self._ready: + if not self._ready: + logger.warning(f"{self} audio dropped — not ready") + yield None + return + + await self._send_audio(audio) yield None - # -- VAD handling (send reset on speech end, like AssemblyAI's ForceEndpoint) -- + # -- Keepalive -- + + async def _send_keepalive(self, silence: bytes): + """Wrap silent PCM in Parakeet's append+commit protocol.""" + logger.debug(f"{self} sending keepalive silence ({len(silence)} bytes)") + await self._send_audio(silence) + + # -- Frame handling (bot-speaking gate + VAD finalization) -- async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) - if isinstance(frame, VADUserStoppedSpeakingFrame): - if self._websocket and self._ready: - self.request_finalize() - try: - await self._websocket.send(json.dumps({"type": "reset", "finalize": True})) - except Exception as e: - logger.error(f"{self} failed to send reset: {e}") + # --- Bot-speaking gate --- + if isinstance(frame, BotStartedSpeakingFrame): + self._audio_gate_open = False + # Discard any stale transcript parts so old Parakeet completions + # from the inter-turn silence period 
don't bleed into the next turn. + self._transcript_parts.clear() + await self._cancel_fallback_finalize() + logger.debug(f"{self} audio gate CLOSED (bot speaking)") + elif isinstance(frame, BotStoppedSpeakingFrame): + self._audio_gate_open = True + logger.debug(f"{self} audio gate OPEN (bot stopped)") + + # --- VAD-based finalization (primary path) --- + elif isinstance(frame, VADUserStartedSpeakingFrame): + # VAD detected speech — cancel any fallback timer since VAD is + # now in control of finalization. + await self._cancel_fallback_finalize() + elif isinstance(frame, VADUserStoppedSpeakingFrame): + await self._cancel_fallback_finalize() + self._finalize_requested = True + self.request_finalize() await self.start_processing_metrics() + if self._transcript_parts: + await self._emit_final_transcript() + else: + # Start a safety timeout — if Parakeet doesn't send `completed` + # within a few seconds, force-finalize. + self._start_finalize_timeout() + + # -- Finalize timeout (VAD fired but no completed from Parakeet) -- + + def _start_finalize_timeout(self): + """Start (or restart) the finalize safety timeout.""" + if self._finalize_timeout_task: + self._finalize_timeout_task.cancel() + self._finalize_timeout_task = self.create_task(self._finalize_timeout_handler()) + + async def _cancel_finalize_timeout(self): + """Cancel any pending finalize timeout.""" + if self._finalize_timeout_task: + await self.cancel_task(self._finalize_timeout_task) + self._finalize_timeout_task = None + + async def _finalize_timeout_handler(self): + """Force-finalize after trailing silence timeout.""" + await asyncio.sleep(_FINALIZE_TIMEOUT_SECS) + if self._finalize_requested: + logger.warning(f"{self} finalize timeout after {_FINALIZE_TIMEOUT_SECS}s — force-finalizing") + if self._transcript_parts: + await self._emit_final_transcript() + else: + # Ghost turn — no transcript arrived. 
+ self._finalize_requested = False + self.confirm_finalize() + + # -- Fallback finalize (Parakeet completed but VAD never fired) -- + + def _start_fallback_finalize(self): + """Start a fallback timer to auto-finalize if VAD doesn't fire.""" + if self._fallback_finalize_task: + self._fallback_finalize_task.cancel() + self._fallback_finalize_task = self.create_task(self._fallback_finalize_handler()) + + async def _cancel_fallback_finalize(self): + """Cancel the fallback finalize timer.""" + if self._fallback_finalize_task: + await self.cancel_task(self._fallback_finalize_task) + self._fallback_finalize_task = None + + async def _fallback_finalize_handler(self): + """Auto-finalize using Parakeet's transcript when VAD missed the speech. + + Because VAD never fired, the downstream LLMUserAggregator has no + active user turn. We push synthetic VAD start/stop frames so the + aggregator sees a proper turn lifecycle and triggers the LLM. + """ + await asyncio.sleep(_FALLBACK_FINALIZE_SECS) + if self._transcript_parts and not self._finalize_requested: + logger.warning( + f"{self} VAD miss — fallback finalizing with Parakeet transcript after {_FALLBACK_FINALIZE_SECS}s" + ) + # Push synthetic VAD start so the aggregator opens a user turn. + await self.push_frame(VADUserStartedSpeakingFrame()) + + self._finalize_requested = True + self.request_finalize() + await self.start_processing_metrics() + await self._emit_final_transcript() + + # Push synthetic VAD stop so the aggregator closes the turn + # and triggers the LLM. 
+ await self.push_frame(VADUserStoppedSpeakingFrame()) # -- Connection management -- @@ -113,6 +291,8 @@ async def _connect(self): async def _disconnect(self): await super()._disconnect() + await self._cancel_finalize_timeout() + await self._cancel_fallback_finalize() if self._receive_task: await self.cancel_task(self._receive_task) @@ -139,16 +319,16 @@ async def _connect_websocket(self): ) self._ready = False - # Wait for ready message from server try: + logger.info(f"Connecting to {self._url}") ready_msg = await asyncio.wait_for(self._websocket.recv(), timeout=5.0) data = json.loads(ready_msg) - if data.get("type") == "ready": - self._ready = True - logger.info(f"{self} connected and ready") + if data.get("type") == "conversation.created": + logger.info("Conversation created successfully") + await self._configure_session() else: logger.warning(f"{self} unexpected initial message: {data}") - self._ready = True + self._ready = True except TimeoutError: logger.warning(f"{self} timeout waiting for ready, proceeding") self._ready = True @@ -159,6 +339,51 @@ async def _connect_websocket(self): logger.error(f"{self} connection failed: {e}") raise + async def _initialize_http_session(self) -> dict: + """Initialize session via HTTP POST to get server defaults (model, sample rate, etc.).""" + parsed = urlparse(self._url) + scheme = "https" if parsed.scheme == "wss" else "http" + http_url = f"{scheme}://{parsed.hostname}" + if parsed.port: + http_url += f":{parsed.port}" + http_url += "/v1/realtime/transcription_sessions" + + headers = {"Content-Type": "application/json"} + if self._api_key: + headers["Authorization"] = f"Bearer {self._api_key}" + + async with httpx.AsyncClient(verify=self._verify) as client: + response = await client.post(http_url, headers=headers, json={}) + response.raise_for_status() + session_data = response.json() + return session_data + + async def _configure_session(self): + """Get server defaults via HTTP, then send transcription_session.update 
over WS.""" + try: + session_config = await self._initialize_http_session() + except Exception as e: + logger.warning(f"{self} HTTP session init failed ({e}), using minimal config") + session_config = {} + + session_config["input_audio_format"] = "pcm16" + + if self._asr_model: + session_config.setdefault("input_audio_transcription", {}) + session_config["input_audio_transcription"]["model"] = self._asr_model + + await self._websocket.send(json.dumps({"type": "transcription_session.update", "session": session_config})) + + try: + response = await asyncio.wait_for(self._websocket.recv(), timeout=5.0) + data = json.loads(response) + if data.get("type") == "transcription_session.updated": + logger.info(f"{self} session configured: {data.get('session', {})}") + else: + logger.warning(f"{self} unexpected session update response: {data}") + except TimeoutError: + logger.warning(f"{self} timeout waiting for session update confirmation") + async def _disconnect_websocket(self): self._ready = False if self._websocket: @@ -181,33 +406,59 @@ async def _receive_messages(self): data = json.loads(message) msg_type = data.get("type") - if msg_type == "transcript": - await self._handle_transcript(data) - elif msg_type == "ready": - self._ready = True - elif msg_type == "error": - logger.error(f"{self} server error: {data.get('message')}") + if msg_type == "error": + logger.error(f"{self} server error: {data}") + elif msg_type == "conversation.item.input_audio_transcription.delta": + delta = data.get("delta", "") + if delta: + await self.push_frame( + InterimTranscriptionFrame(delta, self._user_id, current_time_ms(), language=None) + ) + elif msg_type == "conversation.item.input_audio_transcription.completed": + await self._handle_completed(data) except json.JSONDecodeError: logger.warning(f"{self} non-JSON message received") except Exception as e: logger.error(f"{self} error processing message: {e}") - async def _handle_transcript(self, data: dict): - text = data.get("text", "") - 
is_final = data.get("is_final", False) - - if not text: - # Empty reset response (ghost turn). Push empty finalized - # TranscriptionFrame so the aggregator resolves immediately. - if is_final: - logger.debug(f"{self} empty final transcript (ghost turn)") + async def _handle_completed(self, data: dict): + """Handle a server-side sentence completion event.""" + transcript = data.get("transcript", "").strip() + + if transcript: + self._transcript_parts.append(transcript) + if self._finalize_requested: + # VAD already fired — finalize immediately. + await self._emit_final_transcript() + else: + # VAD hasn't fired yet. Push as interim and start the + # fallback timer so we auto-finalize if VAD never fires. + logger.debug(f"{self} buffered (no VAD yet): {transcript}") + await self.push_frame( + InterimTranscriptionFrame(transcript, self._user_id, current_time_ms(), language=None) + ) + self._start_fallback_finalize() + elif self._finalize_requested: + # Empty completed after VAD fired (silence audio). 
+ if self._transcript_parts: + await self._emit_final_transcript() + else: + logger.debug(f"{self} ghost turn (empty completed)") + self._finalize_requested = False + await self._cancel_finalize_timeout() self.confirm_finalize() - return - if is_final: - self.confirm_finalize() - await self.push_frame(TranscriptionFrame(text, self._user_id, current_time_ms(), language=None)) - await self.stop_processing_metrics() - else: - await self.push_frame(InterimTranscriptionFrame(text, self._user_id, current_time_ms(), language=None)) + async def _emit_final_transcript(self): + """Flush accumulated transcript parts and emit a finalized TranscriptionFrame.""" + full_transcript = " ".join(self._transcript_parts) + self._transcript_parts = [] + self._finalize_requested = False + await self._cancel_finalize_timeout() + await self._cancel_fallback_finalize() + logger.info(f"{self} final transcript: {full_transcript}") + self.confirm_finalize() + await self.push_frame( + TranscriptionFrame(full_transcript, self._user_id, current_time_ms(), language=None, finalized=True) + ) + await self.stop_processing_metrics() diff --git a/src/eva/assistant/pipeline/services.py b/src/eva/assistant/pipeline/services.py index b3fb4f96..70201404 100644 --- a/src/eva/assistant/pipeline/services.py +++ b/src/eva/assistant/pipeline/services.py @@ -185,6 +185,8 @@ def create_stt_service( api_key=api_key, sample_rate=params.get("sample_rate", SAMPLE_RATE), verify=False, + model=params.get("model"), + language=None, ) elif model_lower == "nvidia-baseten": diff --git a/src/eva/assistant/pipeline/turn_config.py b/src/eva/assistant/pipeline/turn_config.py index efe71a0e..e4ec3083 100644 --- a/src/eva/assistant/pipeline/turn_config.py +++ b/src/eva/assistant/pipeline/turn_config.py @@ -111,10 +111,12 @@ def create_turn_stop_strategy( return SpeechTimeoutUserTurnStopStrategy(**strategy_params) elif strategy_type_lower == "turn_analyzer": # TurnAnalyzerUserTurnStopStrategy requires a turn_analyzer instance - # 
If smart_turn_stop_secs is provided, use it; otherwise let SmartTurnParams use its default - smart_params = SmartTurnParams(stop_secs=smart_turn_stop_secs) if smart_turn_stop_secs is not None else None + # smart_turn_stop_secs can be passed via strategy_params (takes precedence) or the explicit argument + params = dict(strategy_params) + stop_secs = params.pop("smart_turn_stop_secs", smart_turn_stop_secs) + smart_params = SmartTurnParams(stop_secs=stop_secs) if stop_secs is not None else None turn_analyzer = LocalSmartTurnAnalyzerV3(params=smart_params) - return TurnAnalyzerUserTurnStopStrategy(turn_analyzer=turn_analyzer, **strategy_params) + return TurnAnalyzerUserTurnStopStrategy(turn_analyzer=turn_analyzer, **params) elif strategy_type_lower == "external": # ExternalUserTurnStopStrategy has no required parameters return ExternalUserTurnStopStrategy(**strategy_params) diff --git a/src/eva/user_simulator/client.py b/src/eva/user_simulator/client.py index bb554300..31be69e5 100644 --- a/src/eva/user_simulator/client.py +++ b/src/eva/user_simulator/client.py @@ -23,7 +23,7 @@ from eva.user_simulator.event_logger import ElevenLabsEventLogger from eva.user_simulator.perturbation import AudioPerturbator from eva.utils.audio_utils import save_pcm_as_wav -from eva.utils.logging import get_logger +from eva.utils.logging import current_record_id, get_logger from eva.utils.prompt_manager import PromptManager logger = get_logger(__name__) @@ -103,6 +103,10 @@ def __init__( self._consecutive_keepalive_count = 0 self._max_consecutive_keepalives = 12 # End call after this many pings without activity (2 minutes) + # Capture the worker's record ID so ElevenLabs callbacks (which run in + # a different thread) can restore it for per-record log routing. + self._record_id = current_record_id.get() + def _on_conversation_end(self, reason: str = "goodbye") -> None: """Signal conversation completion. 
@@ -444,6 +448,7 @@ def _on_user_speaks(self, response: str) -> None: Args: response: The text that the simulated user said """ + current_record_id.set(self._record_id) self._reset_keepalive_counter() logger.info(f"🎭 User (ElevenLabs): {response}") @@ -462,6 +467,7 @@ def _on_user_response_correction(self, original: str, corrected: str) -> None: original: Original response corrected: Corrected response """ + current_record_id.set(self._record_id) logger.debug(f"User response corrected: {original} -> {corrected}") self.event_logger.log_event( @@ -480,6 +486,7 @@ def _on_assistant_speaks(self, transcript: str) -> None: Args: transcript: The text that the assistant said """ + current_record_id.set(self._record_id) self._reset_keepalive_counter() logger.info(f"🤖 Assistant: {transcript}") diff --git a/tests/unit/assistant/test_turn_config.py b/tests/unit/assistant/test_turn_config.py index 8939a4a8..72e135a6 100644 --- a/tests/unit/assistant/test_turn_config.py +++ b/tests/unit/assistant/test_turn_config.py @@ -185,6 +185,26 @@ def test_turn_analyzer_with_stop_secs(self): assert isinstance(passed_params, SmartTurnParams) assert passed_params.stop_secs == 0.8 + def test_turn_analyzer_smart_turn_stop_secs_via_strategy_params(self): + """smart_turn_stop_secs in strategy_params takes precedence over the function argument.""" + from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams + + mock_analyzer = MagicMock() + with patch( + "eva.assistant.pipeline.turn_config.LocalSmartTurnAnalyzerV3", + return_value=mock_analyzer, + ) as mock_cls: + create_turn_stop_strategy( + "turn_analyzer", + {"smart_turn_stop_secs": 1.5}, + smart_turn_stop_secs=0.8, + ) + + call_args = mock_cls.call_args + passed_params = call_args.kwargs["params"] + assert isinstance(passed_params, SmartTurnParams) + assert passed_params.stop_secs == 1.5 + def test_external_stop_strategy(self): """'external' returns ExternalUserTurnStopStrategy.""" result = create_turn_stop_strategy("external", 
{})