Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 100 additions & 11 deletions src/eva/assistant/base_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
See docs/assistant_server_contract.md for the full specification.
"""

import asyncio
import json
import wave
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -97,13 +98,73 @@ async def start(self) -> None:
"""
...

async def stop(self) -> asyncio.Task | None:
    """Stop the server: shut down framework, extract audio, save outputs.

    Concrete template method — subclasses implement _shutdown() instead of stop().

    Sequence:
    1. _shutdown(): framework-specific teardown (server stop, task cancellation)
    2. Auto-compute mixed audio from tracks if not already populated
    3. Extract and clear audio buffers so the caller can release its concurrency
       slot while audio hits disk
    4. save_outputs(): persist audit log, transcript, scenario DBs
    5. Return a deferred asyncio.Task for audio disk writes

    Returns:
        asyncio.Task that completes when audio files are written, or None if
        no audio was recorded.
    """
    await self._shutdown()

    # Auto-compute mixed audio from tracks if not already populated (S2S servers
    # populate user/assistant tracks but not the mixed buffer directly).
    if not self._audio_buffer:
        if self.user_audio_buffer and self.assistant_audio_buffer:
            diff_bytes = abs(len(self.user_audio_buffer) - len(self.assistant_audio_buffer))
            # Bytes → milliseconds assuming 16-bit (2-byte) mono PCM at
            # self._audio_sample_rate — TODO confirm both tracks share that format.
            diff_ms = diff_bytes / (2 * self._audio_sample_rate) * 1000
            if diff_ms > 500:
                logger.warning(
                    f"Audio buffer length mismatch: user={len(self.user_audio_buffer)} "
                    f"assistant={len(self.assistant_audio_buffer)} "
                    f"diff={diff_ms:.0f}ms — mixed recording may be temporally skewed"
                )
            from eva.assistant.audio_bridge import pcm16_mix  # lazy: avoids circular import at module load

            self._audio_buffer = bytearray(
                pcm16_mix(bytes(self.user_audio_buffer), bytes(self.assistant_audio_buffer))
            )
        elif self.user_audio_buffer:
            # Only one track recorded: use it directly as the "mixed" output.
            self._audio_buffer = bytearray(self.user_audio_buffer)
        elif self.assistant_audio_buffer:
            self._audio_buffer = bytearray(self.assistant_audio_buffer)

    # Extract bytes and clear in-memory buffers so the caller can release its
    # concurrency slot while audio writes happen in a background thread.
    mixed_audio = bytes(self._audio_buffer)
    user_audio = bytes(self.user_audio_buffer)
    assistant_audio = bytes(self.assistant_audio_buffer)
    sample_rate = self._audio_sample_rate
    self._audio_buffer.clear()
    self.user_audio_buffer.clear()
    self.assistant_audio_buffer.clear()

    await self.save_outputs()

    # Audio writes are deferred to a thread so this coroutine returns promptly;
    # the caller may await the returned task (or fire-and-forget it).
    if mixed_audio or user_audio or assistant_audio:
        return asyncio.create_task(
            asyncio.to_thread(self._save_audio_deferred, mixed_audio, user_audio, assistant_audio, sample_rate)
        )
    return None

@abstractmethod
async def stop(self) -> None:
"""Stop the server and save all outputs.
async def _shutdown(self) -> None:
"""Framework-specific shutdown: stop server, cancel tasks, etc.

Must:
1. Gracefully shut down the server
2. Call save_outputs() to persist all data
Called by stop() before audio buffer extraction. Implementations should:
1. Check / set the running flag
2. Stop the WebSocket server (set should_exit, await server task)
3. Cancel any pending framework tasks (pipeline, session, etc.)
"""
...

Expand Down Expand Up @@ -148,22 +209,29 @@ async def save_outputs(self) -> None:

Subclasses can override to add framework-specific outputs,
but must call super().save_outputs().

Note: audio files are NOT saved here — they are written by the deferred
asyncio.Task returned by stop() so the concurrency slot is freed first.
"""
# Save audit log
self.audit_log.save(self.output_dir / "audit_log.json")

# Save simplified transcript
transcript_path = self.output_dir / "transcript.jsonl"
self.audit_log.save_transcript_jsonl(transcript_path)

# Save audio recordings
self._save_audio()
# Save transcript (subclasses can override _save_transcript for custom logic)
self._save_transcript()

# Save scenario database states (REQUIRED for deterministic metrics)
self._save_scenario_dbs()

logger.info(f"Outputs saved to {self.output_dir}")

def _save_transcript(self) -> None:
"""Save transcript.jsonl from the audit log.

Subclasses can override to customize transcript handling (e.g. conditional
overwrite logic for S2S vs pipeline modes).
"""
self.audit_log.save_transcript_jsonl(self.output_dir / "transcript.jsonl")

def _save_audio(self) -> None:
"""Save accumulated audio buffers to WAV files.

Expand Down Expand Up @@ -216,6 +284,27 @@ def _save_audio(self) -> None:
1,
)

def _save_audio_deferred(
self,
mixed_audio: bytes,
user_audio: bytes,
assistant_audio: bytes,
sample_rate: int,
) -> None:
"""Write pre-extracted audio bytes to WAV files.

Designed to run in a thread pool via asyncio.to_thread so audio disk
writes happen outside the concurrency semaphore.
"""
if mixed_audio:
self._save_wav_file(mixed_audio, self.output_dir / "audio_mixed.wav", sample_rate, 1)
if user_audio:
self._save_wav_file(user_audio, self.output_dir / "audio_user.wav", sample_rate, 1)
if assistant_audio:
self._save_wav_file(assistant_audio, self.output_dir / "audio_assistant.wav", sample_rate, 1)
if mixed_audio or user_audio or assistant_audio:
logger.info(f"Saved audio files to {self.output_dir} ({len(mixed_audio)} bytes mixed)")

def _save_wav_file(self, audio_data: bytes, file_path: Path, sample_rate: int, num_channels: int) -> None:
"""Save raw 16-bit PCM audio data to a WAV file."""
try:
Expand Down
5 changes: 2 additions & 3 deletions src/eva/assistant/gemini_live_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ async def websocket_root(websocket: WebSocket):

logger.info(f"GeminiLive server started on ws://localhost:{self.port}")

async def stop(self) -> None:
"""Stop the server, save outputs."""
async def _shutdown(self) -> None:
"""Stop the GeminiLive server."""
if not self._running:
return
self._running = False
Expand All @@ -274,7 +274,6 @@ async def stop(self) -> None:
self._server = None
self._server_task = None

await self.save_outputs()
logger.info(f"GeminiLive server stopped on port {self.port}")

# ------------------------------------------------------------------
Expand Down
9 changes: 2 additions & 7 deletions src/eva/assistant/openai_realtime_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ async def websocket_root(websocket: WebSocket):

logger.info(f"OpenAI Realtime server started on ws://localhost:{self.port}")

async def stop(self) -> None:
"""Stop the server and save all outputs."""
async def _shutdown(self) -> None:
"""Stop the OpenAI Realtime server."""
if not self._running:
return

Expand All @@ -181,13 +181,8 @@ async def stop(self) -> None:
self._server = None
self._server_task = None

await self.save_outputs()
logger.info(f"OpenAI Realtime server stopped on port {self.port}")

async def save_outputs(self) -> None:
"""Save all outputs including mixed audio."""
await super().save_outputs()

async def _handle_session(self, websocket: WebSocket) -> None:
"""Handle a single WebSocket session.

Expand Down
29 changes: 17 additions & 12 deletions src/eva/assistant/pipecat_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,27 +181,26 @@ async def websocket_root(websocket: WebSocket):

logger.info(f"Assistant server started on ws://localhost:{self.port}")

async def stop(self) -> None:
"""Stop the server and save outputs."""
async def _shutdown(self) -> None:
"""Stop the Pipecat pipeline and uvicorn server."""
if not self._running:
return

self._running = False

# Stop the pipeline task
# Cancel pipeline task first so no more audio arrives before base stop()
# extracts the buffers.
if self._task:
await self._task.cancel()
self._task = None

# Stop the server gracefully
# Stop the uvicorn server gracefully.
if self._server:
self._server.should_exit = True
# Wait briefly for graceful shutdown, then cancel if needed
if self._server_task:
try:
await asyncio.wait_for(self._server_task, timeout=5.0)
except TimeoutError:
# Force cancellation if graceful shutdown times out
self._server_task.cancel()
try:
await self._server_task
Expand All @@ -212,9 +211,6 @@ async def stop(self) -> None:
self._server = None
self._server_task = None

# Save outputs
await self.save_outputs()

logger.info(f"Assistant server stopped on port {self.port}")

async def _handle_session(self, websocket) -> None:
Expand Down Expand Up @@ -719,8 +715,20 @@ def _current_iso_timestamp() -> str:
"""Return the current time as an ISO 8601 string with timezone."""
return time_now_iso8601()

def _save_transcript(self) -> None:
    """Pipecat-specific transcript handling.

    S2S mode: always rebuild transcript.jsonl from the audit log (correct
    ordering). Pipeline mode: only write if no transcript was already
    written incrementally.
    """
    path = self.output_dir / "transcript.jsonl"
    is_s2s = isinstance(self.pipeline_config, SpeechToSpeechConfig)
    if is_s2s or not path.exists():
        self.audit_log.save_transcript_jsonl(path)

async def save_outputs(self) -> None:
"""Save all outputs, with pipecat-specific additions."""
await super().save_outputs()

# Save agent performance stats (pipecat-specific: AgenticSystem tracking)
if self.agentic_system:
try:
Expand All @@ -729,9 +737,6 @@ async def save_outputs(self) -> None:
except Exception as e:
logger.error(f"Error saving agent perf stats: {e}", exc_info=True)

# Call base class to save audit_log, audio, scenario DBs, latencies
await super().save_outputs()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicating most of super().save_outputs(), instead of calling it, worries me about them going out of sync. Have you considered refactoring so that you can keep calling super().save_outputs()?



async def override__maybe_trigger_user_turn_stopped(self):
"""Trigger user turn stopped if conditions are met.
Expand Down
63 changes: 36 additions & 27 deletions src/eva/metrics/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ async def run(self, contexts: dict[str, Any] | None = None) -> MetricsRunResult:
targeted_ids = {rid for rid, _ in targeted}

# Run targeted records concurrently; LiteLLM limits concurrent API calls.
tasks = [self._run_and_save_record(rid, rdir) for rid, rdir in targeted]
tasks = [self.run_and_save_record(rid, rdir) for rid, rdir in targeted]
results = await asyncio.gather(*tasks, return_exceptions=True)

for (record_id, _), result in zip(targeted, results):
Expand Down Expand Up @@ -323,7 +323,7 @@ async def run(self, contexts: dict[str, Any] | None = None) -> MetricsRunResult:
metric_failures=metric_failures,
)

async def _run_and_save_record(self, record_id: str, record_dir: Path) -> RecordMetrics | None:
async def run_and_save_record(self, record_id: str, record_dir: Path) -> RecordMetrics | None:
"""Run metrics for a record and save results, merging with existing metrics.

Skips computation when possible:
Expand Down Expand Up @@ -418,7 +418,7 @@ async def _run_record(self, record_id: str, record_dir: Path) -> RecordMetrics:
logger.debug(f"Computing metrics for record: {record_id}")

# Load conversation data
context = self._load_context(record_id, record_dir)
context = await self._load_context(record_id, record_dir)

# Determine which metrics to run for this record
metrics_to_run = self.metrics
Expand Down Expand Up @@ -480,7 +480,7 @@ async def compute_metric(metric: BaseMetric) -> tuple[str, MetricScore]:

return RecordMetrics(record_id=record_id, context=context_dict, metrics=metric_scores)

def _load_context(self, record_id: str, record_dir: Path) -> MetricContext:
async def _load_context(self, record_id: str, record_dir: Path) -> MetricContext:
"""Load all data needed for metric computation."""
# Strip _trial_N suffix to get base record ID for dataset lookup.
base_record_id, _ = parse_trial_record_id(record_id)
Expand All @@ -492,17 +492,41 @@ def _load_context(self, record_id: str, record_dir: Path) -> MetricContext:

gt = record.ground_truth

# Load conversation result
# Load conversation result and scenario databases in parallel (non-blocking I/O)
result_path = record_dir / "result.json"
result_data = {}
if result_path.exists():
result_data = json.loads(result_path.read_text())
initial_db_path = record_dir / "initial_scenario_db.json"
final_db_path = record_dir / "final_scenario_db.json"

if not initial_db_path.exists():
raise FileNotFoundError(
f"Initial scenario database not found at {initial_db_path}. "
"This is required for deterministic task completion metrics."
)
if not final_db_path.exists():
raise FileNotFoundError(
f"Final scenario database not found at {final_db_path}. "
"This is required for deterministic task completion metrics."
)

async def _read_optional(path: Path) -> str:
return await asyncio.to_thread(path.read_text) if path.exists() else "{}"

result_text, initial_db_text, final_db_text = await asyncio.gather(
_read_optional(result_path),
asyncio.to_thread(initial_db_path.read_text),
asyncio.to_thread(final_db_path.read_text),
)

result_data = json.loads(result_text)

# Create ConversationResult object
result = ConversationResult(**result_data)

metrics_context = self._context_cache.get(record_id) or self.metrics_processor.process_record(
result, record_dir, pipeline_type=self._pipeline_type
# Use postprocessor to process logs and create enriched context.
# Check cache first (populated by process_records() pre-pass); fall back to
# processing in a thread to avoid blocking the event loop.
metrics_context = self._context_cache.get(record_id) or await asyncio.to_thread(
self.metrics_processor.process_record, result, record_dir, pipeline_type=self._pipeline_type
)

# Get agent instructions and tools from config
Expand All @@ -515,23 +539,8 @@ def _load_context(self, record_id: str, record_dir: Path) -> MetricContext:

user_persona = record.user_config["user_persona"]

# Load scenario database states (REQUIRED for deterministic metrics)
initial_db_path = record_dir / "initial_scenario_db.json"
final_db_path = record_dir / "final_scenario_db.json"

if not initial_db_path.exists():
raise FileNotFoundError(
f"Initial scenario database not found at {initial_db_path}. "
"This is required for deterministic task completion metrics."
)
if not final_db_path.exists():
raise FileNotFoundError(
f"Final scenario database not found at {final_db_path}. "
"This is required for deterministic task completion metrics."
)

initial_scenario_db = json.loads(initial_db_path.read_text())
final_scenario_db = json.loads(final_db_path.read_text())
initial_scenario_db = json.loads(initial_db_text)
final_scenario_db = json.loads(final_db_text)

# Get hashes from result or compute if needed
initial_scenario_db_hash = getattr(result, "initial_scenario_db_hash", None) or get_dict_hash(
Expand Down
Loading
Loading