# TTS WebSocket Evaluation

Runs end-to-end measurements for:
- Latency (client-side optimization metric: first real text sent → first audio received)
- Audio quality sanity checks (RMS, peak, clipping, silence)
- Optional alignment accuracy vs WhisperX (if installed)

Prereqs:
- Server running at `ws://localhost:8000/tts`
- `websockets`, `numpy` installed (already in requirements.txt)
- For alignment eval: `whisperx`, `torch`, `torchaudio` in the same environment.


In [None]:
import asyncio, base64, json, time, math, sys
from typing import List, Dict, Tuple
import numpy as np
import websockets

# Optional: WhisperX aligner
try:
    from app.aligner import WhisperXAligner
    HAVE_WHISPERX = True
except Exception:
    HAVE_WHISPERX = False


In [None]:
def pcm_from_b64(audio_b64: str) -> np.ndarray:
    """Decode base64-encoded 16-bit PCM into float32 samples.

    The decoded bytes are read as ``np.int16`` values and divided by 32768,
    so the returned samples lie in ``[-1.0, 1.0)``.
    """
    raw_bytes = base64.b64decode(audio_b64)
    int_samples = np.frombuffer(raw_bytes, dtype=np.int16)
    return int_samples.astype(np.float32) / 32768.0

async def run_session(
    text: str,
    uri: str = "ws://localhost:8000/tts",
    chunk_size: int = 512,
    delay: float = 0.0,
    sample_rate: int = 44100,
):
    """Stream *text* to the TTS WebSocket and collect audio, alignment and timing.

    Messages sent, in order: a priming ``{"text": " "}`` message, the text in
    ``chunk_size``-character pieces, a message with ``"flush": True``, and a
    final empty-text message (presumably the end-of-input marker; confirm
    against the server).  Responses are then read until the connection closes.

    Args:
        text: Text to synthesize.
        uri: WebSocket endpoint of the TTS server.
        chunk_size: Number of characters per text message; must be positive.
        delay: Seconds to sleep after each text message.
        sample_rate: Sample rate (Hz) assumed for the returned PCM.  It is
            used to convert the running sample count into each chunk's start
            offset and is reported back as ``"sr"``.

    Returns:
        A dict with:
        ``"audio"``: float32 array of all received samples, concatenated;
        ``"sr"``: the sample rate used;
        ``"alignment"``: ``chars``, ``char_start_times_ms`` and
        ``char_durations_ms`` lists, with start times made absolute by adding
        the start offset (ms) of the audio chunk they arrived with;
        ``"timing"``: ``time.perf_counter()`` readings ``t_start`` (before the
        first text chunk is sent), ``t_first_send`` (after it is sent),
        ``t_first_audio`` and ``t_last_audio`` (arrival of the first/last
        message carrying audio).  Any of them is ``None`` if the event never
        happened.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    # A zero step makes range() fail only after connecting, and a negative
    # step silently sends no text at all; reject both up front.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    pcm_chunks: List[np.ndarray] = []
    align_chars: List[str] = []
    align_starts: List[float] = []
    align_durs: List[float] = []

    sr = sample_rate
    samples_so_far = 0
    t_first_send = None
    t_first_audio = None
    t_start = None
    t_last_audio = None

    async with websockets.connect(uri, max_size=None) as ws:
        # Priming message sent before the timed portion of the session.
        await ws.send(json.dumps({"text": " ", "flush": False}))

        for i in range(0, len(text), chunk_size):
            if t_start is None:
                t_start = time.perf_counter()
            await ws.send(json.dumps({"text": text[i:i + chunk_size], "flush": False}))
            if t_first_send is None:
                t_first_send = time.perf_counter()
            await asyncio.sleep(delay)

        await ws.send(json.dumps({"text": "", "flush": True}))
        await ws.send(json.dumps({"text": "", "flush": False}))

        try:
            while True:
                msg = await ws.recv()
                # Timestamp the arrival before any decoding so that client-side
                # JSON/base64 work is not counted as latency.
                t_recv = time.perf_counter()
                if isinstance(msg, bytes):
                    msg = msg.decode("utf-8")
                payload = json.loads(msg)
                audio_b64 = payload.get("audio", "")
                if not audio_b64:
                    # Messages without audio do not count towards the metrics.
                    continue
                audio = pcm_from_b64(audio_b64)
                pcm_chunks.append(audio)
                # Offset of this chunk within the whole utterance, in ms.
                chunk_start_ms = (samples_so_far / float(sr)) * 1000.0
                samples_so_far += audio.shape[0]

                aln = payload.get("alignment", {}) or {}
                chars = aln.get("chars", [])
                starts = aln.get("char_start_times_ms", [])
                durs = aln.get("char_durations_ms", [])
                for ch, s, d in zip(chars, starts, durs):
                    align_chars.append(ch)
                    # Chunk-relative start time -> absolute session time.
                    align_starts.append(chunk_start_ms + float(s))
                    align_durs.append(float(d))

                if t_first_audio is None:
                    t_first_audio = t_recv
                t_last_audio = t_recv
        except websockets.ConnectionClosed:
            # The connection closing is what ends the receive loop.
            pass

    audio_out = np.concatenate(pcm_chunks) if pcm_chunks else np.zeros(0, dtype=np.float32)
    return {
        "audio": audio_out,
        "sr": sr,
        "alignment": {
            "chars": align_chars,
            "char_start_times_ms": align_starts,
            "char_durations_ms": align_durs,
        },
        "timing": {
            "t_first_send": t_first_send,
            "t_first_audio": t_first_audio,
            "t_last_audio": t_last_audio,
            "t_start": t_start,
        },
    }


In [None]:
def audio_quality(audio: np.ndarray) -> Dict[str, float]:
    """Compute basic sanity metrics for an audio signal.

    Returns a dict with the RMS level, the peak absolute value, the fraction
    of samples whose magnitude is >= 0.999 (``clipping_ratio``) and the
    fraction whose magnitude is < 1e-4 (``silence_ratio``).  An empty signal
    is reported as all-silent with zero level.
    """
    if audio.size == 0:
        return {"rms": 0.0, "peak": 0.0, "clipping_ratio": 0.0, "silence_ratio": 1.0}

    # Every metric except RMS is derived from the sample magnitudes.
    magnitude = np.abs(audio)
    return {
        "rms": float(np.sqrt(np.mean(np.square(audio)))),
        "peak": float(np.max(magnitude)),
        "clipping_ratio": float(np.mean(magnitude >= 0.999)),
        "silence_ratio": float(np.mean(magnitude < 1e-4)),
    }


In [None]:
import re, difflib

def word_spans(text: str):
    """Return ``(start, end)`` offsets of every run of non-whitespace in *text*."""
    spans = []
    for match in re.finditer(r"\S+", text):
        spans.append(match.span())
    return spans

def char_to_word_map(spans, length):
    """Build a per-character lookup from character index to word index.

    ``spans`` is a sequence of ``(start, end)`` half-open ranges; position
    ``k`` of the result holds the index of the span covering character ``k``,
    or ``None`` when no span covers it.  The result has ``length`` entries.
    """
    mapping = [None] * length
    for word_idx, span in enumerate(spans):
        first, stop = span
        for pos in range(first, stop):
            mapping[pos] = word_idx
    return mapping

def align_words(aln: Dict[str, List], text: str):
    """Project a character-level alignment onto the words of *text*.

    The aligned characters (``aln["chars"]``) are matched against the
    characters of *text* with ``difflib.SequenceMatcher``.  Every matched
    character that falls inside a word of *text* contributes its time
    interval to that word.

    Args:
        aln: Dict with parallel lists ``chars``, ``char_start_times_ms`` and
            ``char_durations_ms``.  Missing keys are treated as empty; if the
            lists differ in length only the common prefix is used.
        text: Reference text whose whitespace-delimited words are timed.

    Returns:
        Dict mapping word index (position in ``word_spans(text)``) to a
        ``(start_ms, end_ms)`` tuple: the earliest start and latest end over
        the word's matched characters.  Words with no matched character are
        absent.
    """
    chars = aln.get("chars", []) or []
    starts = aln.get("char_start_times_ms", []) or []
    durs = aln.get("char_durations_ms", []) or []

    # A character is only usable if it has both a start time and a duration;
    # trimming to the common length avoids an IndexError on ragged input.
    usable = min(len(chars), len(starts), len(durs))
    chars = list(chars[:usable])

    spans = word_spans(text)
    text_chars = list(text)
    c2w = char_to_word_map(spans, len(text_chars))

    # autojunk must be off: with the default heuristic, any element making up
    # more than ~1% of a sequence of 200+ items is dropped as a match anchor.
    # In natural-language text that is almost every letter, so matching
    # degrades badly (or finds nothing) for realistic inputs.
    matcher = difflib.SequenceMatcher(None, chars, text_chars, autojunk=False)

    word_times = {}
    for a0, b0, size in matcher.get_matching_blocks():
        for offset in range(size):
            si = a0 + offset
            # b0 + offset is always a valid index into text_chars (and c2w).
            widx = c2w[b0 + offset]
            if widx is None:
                # Matched character is whitespace, not part of any word.
                continue
            start = float(starts[si])
            end = start + float(durs[si])
            if widx in word_times:
                s0, e0 = word_times[widx]
                word_times[widx] = (min(s0, start), max(e0, end))
            else:
                word_times[widx] = (start, end)
    return word_times


def compare_alignment_words(server_aln, ref_aln, text):
    """Compare word start times of two character alignments of *text*.

    Both alignments are reduced to per-word times with ``align_words``.  For
    every word present in both, the signed error is the server start time
    minus the reference start time (ms).

    Returns a dict that always holds ``words_total``, ``matched`` and
    ``unmatched``; when at least one word matched it also holds the mean,
    median, 90th and 99th percentile of the signed error and the maximum
    absolute error.
    """
    total_words = len(word_spans(text))
    server_words = align_words(server_aln, text)
    ref_words = align_words(ref_aln, text)

    # Signed start-time error for each word found in both alignments.
    deltas = [
        server_words[w][0] - ref_words[w][0]
        for w in range(total_words)
        if w in server_words and w in ref_words
    ]
    n_matched = len(deltas)

    summary = {
        "words_total": total_words,
        "matched": n_matched,
        "unmatched": total_words - n_matched,
    }
    if not deltas:
        return summary

    err = np.array(deltas, dtype=np.float32)

    def percentile(p):
        return float(np.percentile(err, p))

    summary.update(
        {
            "mean_error_ms": float(err.mean()),
            "median_error_ms": percentile(50),
            "p90_error_ms": percentile(90),
            "p99_error_ms": percentile(99),
            "max_abs_error_ms": float(np.max(np.abs(err))),
        }
    )
    return summary


In [None]:
# Run evaluation on sample_text.txt (server must be running)
text_path = "sample_text.txt"
with open(text_path, "r", encoding="utf-8") as f:
    sample_text = f.read()

# NOTE: asyncio.run() raises RuntimeError when an event loop is already
# running (as in a Jupyter kernel); there, use
# `result = await run_session(...)` instead.
result = asyncio.run(run_session(sample_text, chunk_size=512, delay=0.0))
audio = result["audio"]
sr = result["sr"]
alignment = result["alignment"]
timing = result["timing"]

# Latency metrics (client-side optimization metric).
# Timestamps are None when the event never happened, so test for None
# explicitly rather than relying on the truthiness of a float.
ttft = None
if timing["t_first_audio"] is not None and timing["t_first_send"] is not None:
    ttft = timing["t_first_audio"] - timing["t_first_send"]
total = None
if timing["t_last_audio"] is not None and timing["t_first_send"] is not None:
    total = timing["t_last_audio"] - timing["t_first_send"]

print({
    "ttft_s": ttft,
    "total_s": total,
    # "tokens" here is the number of characters in the input text.
    "tokens": len(sample_text),
    # Truthiness is intentional: it also guards against dividing by 0.0.
    "tokens_per_s": (len(sample_text) / total) if total else None,
})

print("Audio quality:", audio_quality(audio))


In [None]:
# Compare the server's alignment against a WhisperX reference alignment of
# the synthesized audio; skipped when the optional aligner failed to import.
if not HAVE_WHISPERX:
    print("WhisperX not installed; skipping alignment comparison.")
else:
    print("Running WhisperX alignment for reference...")
    ax = WhisperXAligner()
    # A falsy result from the aligner is treated as an empty alignment.
    ref_aln = ax.align(audio, sr, sample_text) or {}
    stats = compare_alignment_words(alignment, ref_aln, sample_text)
    print("Alignment stats (word-level vs WhisperX):")
    for k, v in stats.items():
        print(f"  {k}: {v}")
