### This implementation provides real-time text-to-speech streaming with character alignment
Design goal: Create a production-ready TTS service that can handle various edge cases

In [1]:
# Install kokoro
!pip install -q kokoro soundfile
!git clone https://huggingface.co/hexgrad/Kokoro-82M

# Install espeak, used for out-of-dictionary fallback
!apt-get -qq -y install espeak-ng > /dev/null 2>&1


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Cloning into 'Kokoro-82M'...
remote: Enumerating objects: 421, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 421 (delta 18), reused 0 (delta 0), pack-reused 391 (from 1)[K
Receiving objects: 100% (421/421), 1.83 MiB | 9.11 MiB/s, done.
Resolving deltas: 100% (239/239), done.
Filtering content: 100% (61/61), 344.32 MiB | 26.70 MiB/s, done.


In [2]:
import asyncio
import websockets
import json
import base64
import numpy as np
import time
import re
import unicodedata
import threading
import concurrent.futures
from typing import List, Dict, Optional, Tuple, AsyncGenerator, Union
from dataclasses import dataclass
import logging
from io import BytesIO
import soundfile as sf
import types
from collections.abc import Iterator

import nest_asyncio
nest_asyncio.apply()

In [3]:
# Configure root logging at INFO level and create the module-level logger that
# every function and class below uses for progress and error reporting.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import the Kokoro pipeline class. If the package is missing, log a hint and
# re-raise so the notebook stops here instead of failing later with a NameError.
try:
    from kokoro import KPipeline
except ImportError:
    logger.error("Kokoro TTS not found. Please install it first.")
    raise

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [4]:
# Character weighting for alignment heuristics
# These weights are based on phonetic duration research - vowels typically last longer
# than consonants, and punctuation creates natural pauses in speech
_VOWELS = set(list("aeiouAEIOU"))
_PUNCT = set(list(",.;:!?\"'()[]{}"))
_SPACE = set([" ", "\t", "\n"])

def _char_weight(c: str) -> float:
    """Calculate relative duration weight for a character.

    This heuristic approach provides reasonable character timing without
    requiring expensive neural alignment models. Based on linguistic research:
    - Vowels are typically held longer in speech
    - Punctuation creates natural pauses
    - Consonants have baseline duration
    """
    if c in _SPACE: return 0.55
    if c in _PUNCT: return 0.65
    if c in _VOWELS: return 1.25
    return 1.0

def alignment_for_text(text: str, total_ms: float) -> Dict[str, List[float]]:
    """Estimate per-character timing for a synthesized piece of text.

    The total duration is shared out between the characters in proportion to
    their heuristic weights (see ``_char_weight``); no audio analysis is done.

    Args:
        text: The text that was synthesized.
        total_ms: Duration of the corresponding audio in milliseconds.

    Returns:
        A dict with three parallel lists: ``chars`` (the characters),
        ``char_start_times_ms`` and ``char_durations_ms`` (both rounded to
        three decimals). All three lists are empty when ``text`` is empty.
    """
    if not text:
        return {"chars": [], "char_start_times_ms": [], "char_durations_ms": []}

    chars = list(text)
    n_chars = len(chars)

    # One weight per character
    weights = np.fromiter(
        (_char_weight(ch) for ch in chars), dtype=np.float64, count=n_chars
    )
    weight_sum = float(weights.sum())

    # Defensive fallback: with a non-positive total, share the time equally
    if weight_sum <= 0:
        weights = np.ones(n_chars, dtype=np.float64)
        weight_sum = float(n_chars)

    # Each character receives its weighted share of the total duration
    durations = (weights / weight_sum) * float(total_ms)

    # A character starts where the previous one ends; the first starts at 0
    end_times = np.cumsum(durations)
    starts = np.zeros_like(durations)
    starts[1:] = end_times[:-1]

    return {
        "chars": chars,
        "char_start_times_ms": np.round(starts, 3).tolist(),
        "char_durations_ms": np.round(durations, 3).tolist(),
    }

In [5]:
def _safe_to_numpy(x, target_dtype=np.float32) -> np.ndarray:
    """Safely convert various types to numpy array.
    
    This function handles the complexity of different TTS output formats.
    Kokoro TTS (and other engines) may return various formats:
    - Raw numpy arrays
    - PyTorch tensors (need CPU conversion)
    - Bytes (PCM data)
    - Lists of chunks
    - Multi-dimensional arrays (stereo -> mono conversion)
    
    The robust handling prevents crashes when TTS engines change their output format.
    """
    if x is None:
        return np.array([], dtype=target_dtype)

    # Handle PyTorch tensors - detach from compute graph and move to CPU
    if hasattr(x, "detach"):
        x = x.detach()
    if hasattr(x, "cpu"):
        x = x.cpu()
    if hasattr(x, "numpy"):
        x = x.numpy()

    # Handle raw PCM bytes (common in audio processing)
    if isinstance(x, (bytes, bytearray, memoryview)):
        # Convert 16-bit PCM to float32 in [-1, 1] range
        return np.frombuffer(x, dtype=np.int16).astype(target_dtype) / 32768.0

    # Convert to numpy with error handling
    try:
        arr = np.asarray(x, dtype=target_dtype)
    except (ValueError, TypeError):
        # If conversion fails, try to handle as flatten nested sequences
        if hasattr(x, '__iter__') and not isinstance(x, (str, bytes)):
            try:
                flat_list = []
                for item in x:
                    sub_arr = _safe_to_numpy(item, target_dtype)
                    if sub_arr.size > 0:
                        flat_list.extend(sub_arr.flatten())
                return np.array(flat_list, dtype=target_dtype) if flat_list else np.array([], dtype=target_dtype)
            except:
                return np.array([], dtype=target_dtype)
        else:
            return np.array([], dtype=target_dtype)

    # Ensure we have a valid array
    if not isinstance(arr, np.ndarray):
        return np.array([], dtype=target_dtype)

    # Handle multi-dimensional arrays (make mono)
    if arr.ndim > 1:
        if arr.ndim == 2 and arr.shape[0] in (1, 2) and arr.shape[1] > arr.shape[0]:
            # Likely (channels, samples) format
            arr = arr.mean(axis=0, dtype=target_dtype)
        else:
            # Average across last dimension
            arr = arr.mean(axis=-1, dtype=target_dtype)

    # Clean up any NaN/Inf values
    arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0)

    return arr.astype(target_dtype, copy=False)

In [6]:
def _process_kokoro_output(output) -> np.ndarray:
    """Process Kokoro TTS output into a clean numpy array.
    
    Kokoro TTS may return different output formats:
    - Tuple: (audio_data, sample_rate) 
    - Dict: {'audio': data, 'sr': sample_rate, ...}
    - Generator: streaming chunks of audio
    - Raw array: direct audio data
    
    This function normalizes all these formats into a single numpy array.
    The robust handling prevents service crashes when the TTS engine updates.
    """
    if output is None:
        return np.array([], dtype=np.float32)

    # Handle tuple outputs like (audio, sample_rate)
    if isinstance(output, tuple):
        # Could be (audio, sr) or (audio, other_info)
        audio_data = output[0]
    elif isinstance(output, dict):
        # Look for common audio keys
        for key in ['audio', 'samples', 'waveform', 'data']:
            if key in output:
                audio_data = output[key]
                break
        else:
            return np.array([], dtype=np.float32)
    else:
        audio_data = output

    # Handle generators and iterators
    if isinstance(audio_data, (types.GeneratorType, Iterator)):
        chunks = []
        try:
            for chunk in audio_data:
                chunk_array = _safe_to_numpy(chunk)
                if chunk_array.size > 0:
                    chunks.append(chunk_array)
        except Exception as e:
            logger.warning(f"Error processing generator output: {e}")

        if chunks:
            # Ensure all chunks are 1D before concatenating
            flat_chunks = []
            for chunk in chunks:
                flat_chunk = chunk.flatten()
                if flat_chunk.size > 0:
                    flat_chunks.append(flat_chunk)

            return np.concatenate(flat_chunks) if flat_chunks else np.array([], dtype=np.float32)
        else:
            return np.array([], dtype=np.float32)

    # Handle lists and other sequences
    if isinstance(audio_data, (list, tuple)):
        if len(audio_data) == 0:
            return np.array([], dtype=np.float32)

        # Try to process as list of chunks
        chunks = []
        for item in audio_data:
            chunk_array = _safe_to_numpy(item)
            if chunk_array.size > 0:
                chunks.append(chunk_array.flatten())

        if chunks:
            return np.concatenate(chunks)
        else:
            # Fallback: try to convert the whole list
            return _safe_to_numpy(audio_data)

    # Direct conversion
    return _safe_to_numpy(audio_data)

In [7]:
def audio_to_pcm16_base64(samples: np.ndarray, sr: int) -> str:
    """Encode float audio as Base64 text of 44.1 kHz 16-bit PCM.

    Steps: flatten to 1-D, linearly resample to 44100 Hz when ``sr`` differs
    (only when there is more than one sample), clip to [-1, 1], scale by 32767
    with truncation to int16, then Base64-encode the raw bytes.

    Args:
        samples: Audio samples as floats, expected in [-1, 1].
        sr: Sample rate of ``samples`` in Hz.

    Returns:
        ASCII Base64 string, or "" when ``samples`` is None or empty.
    """
    if samples is None or samples.size == 0:
        return ""

    TARGET_SR = 44100  # output sample rate

    wave = samples.flatten()
    n_in = len(wave)

    if sr != TARGET_SR:
        ratio = TARGET_SR / float(sr)
        n_out = max(1, int(round(n_in * ratio)))

        # A single sample cannot be interpolated; it is passed through as is
        if n_in > 1:
            # Linear interpolation: cheap, at the cost of some fidelity
            src_positions = np.arange(n_in, dtype=np.float64)
            dst_positions = np.linspace(0, n_in - 1, n_out)
            wave = np.interp(dst_positions, src_positions, wave)

    # Clip first so the int16 conversion cannot overflow
    clipped = np.clip(wave, -1.0, 1.0)
    pcm_bytes = (clipped * 32767.0).astype(np.int16).tobytes()
    return base64.b64encode(pcm_bytes).decode('ascii')

In [8]:
class MathNotationProcessor:
    """Turns LaTeX math embedded in text into words a TTS engine can speak.

    Only text inside inline or display math delimiters is rewritten; everything
    outside the delimiters is returned untouched.

    The processor handles:
    - Common operators (times, divided by, etc.)
    - Greek letters (alpha, beta, etc.)
    - Fractions, square roots, exponents and subscripts
    - Minus signs, spoken differently when unary and when binary
    """

    def __init__(self, unary_minus_word="negative", binary_minus_word="minus"):
        """Build the symbol table and compile the regular expressions.

        Args:
            unary_minus_word: Word used for a sign, e.g. "-5" -> "negative 5".
            binary_minus_word: Word used for subtraction, e.g. "7 - 3" ->
                "7 minus 3".
        """
        self.unary_minus_word = unary_minus_word
        self.binary_minus_word = binary_minus_word

        # Keys are regex patterns matching a literal LaTeX command;
        # values are the spoken replacement (padded with spaces).
        self.math_symbols = {
            r'\\times': ' times ',
            r'\\cdot': ' dot ',
            r'\\div': ' divided by ',
            r'\\pm': ' plus or minus ',
            r'\\mp': ' minus or plus ',
            r'\\leq': ' less than or equal to ',
            r'\\geq': ' greater than or equal to ',
            r'\\neq': ' not equal to ',
            r'\\approx': ' approximately equal to ',
            r'\\equiv': ' equivalent to ',
            r'\\infty': ' infinity ',
            r'\\sum': ' sum ',
            r'\\prod': ' product ',
            r'\\int': ' integral ',
            r'\\partial': ' partial ',
            r'\\nabla': ' nabla ',
            r'\\alpha': ' alpha ',
            r'\\beta': ' beta ',
            r'\\gamma': ' gamma ',
            r'\\delta': ' delta ',
            r'\\epsilon': ' epsilon ',
            r'\\theta': ' theta ',
            r'\\lambda': ' lambda ',
            r'\\mu': ' mu ',
            r'\\pi': ' pi ',
            r'\\sigma': ' sigma ',
            r'\\phi': ' phi ',
            r'\\psi': ' psi ',
            r'\\omega': ' omega ',
            r'\\sqrt': ' square root ',
        }

        # Precompiled regex patterns for better performance
        # Math expressions can be inline \(...\) or display \[...\]
        self._re_inline = re.compile(r'\\\((.*?)\\\)', re.DOTALL)
        self._re_display = re.compile(r'\\\[(.*?)\\\]', re.DOTALL)
        self._re_left_right = re.compile(r'\\left|\\right')
        self._re_frac = re.compile(r'\\frac\{([^{}]+)\}\{([^{}]+)\}')
        self._re_sqrt = re.compile(r'\\sqrt\{([^{}]+)\}')
        self._re_pow_simple = re.compile(r'\^(\w)')
        self._re_pow_braced = re.compile(r'\^\{([^}]+)\}')
        self._re_sub_simple = re.compile(r'_(\w)')
        self._re_sub_braced = re.compile(r'_\{([^}]+)\}')
        # A minus is unary when it opens the expression or follows an opening
        # bracket or operator (spaces allowed in between), or when it follows a
        # space and sits directly against its operand ("x = -5", "3 times -2").
        # A minus with spaces on both sides after an operand ("7 - 3") is
        # deliberately NOT matched here so that it is read as subtraction.
        self._re_unary_minus = re.compile(
            r'(^|[(\[{=+*/<>,])(\s*)-(?=\s*[\w(])'
            r'|(?<=\s)-(?=[\w(])'
        )
        # Everything left over that sits between two operands is subtraction,
        # including the tight form "a-b".
        self._re_binary_minus = re.compile(
            r'\s-\s|(?<=\w)-(?=\s)|(?<=\s)-(?=\w)|(?<=[\w)\]}])-(?=[\w(\[{])'
        )
        self._re_num_letter = re.compile(r'(?<=\d)(?=[A-Za-z])')
        self._re_letter_num = re.compile(r'(?<=[A-Za-z])(?=\d)')
        self._re_multi_space = re.compile(r'\s+')

    def process_math_text(self, text: str) -> str:
        r"""Replace every LaTeX math span in ``text`` with speakable words.

        Inline spans ``\(...\)`` are processed first, then display spans
        ``\[...\]``. Empty or ``None`` input is returned unchanged.
        """
        if not text:
            return text

        # Process inline math: \(expression\)
        text = self._re_inline.sub(self._process_math_expression, text)
        # Process display math: \[expression\]
        text = self._re_display.sub(self._process_math_expression, text)
        return text

    def _unary_minus_repl(self, match) -> str:
        """Replacement callback for ``_re_unary_minus``.

        Keeps the context that preceded the minus (groups 1 and 2, which are
        unset for the second alternative) and substitutes the unary word.
        """
        lead = (match.group(1) or '') + (match.group(2) or '')
        return f"{lead}{self.unary_minus_word} "

    def _process_math_expression(self, match) -> str:
        """Convert the body of one math span to words.

        Braced structures (fractions, square roots) are rewritten before the
        plain symbol table is applied. The table also contains a bare
        ``\\sqrt`` entry, and applying it first would destroy the braces that
        the structural patterns need.

        Args:
            match: Regex match whose group 1 is the math body.

        Returns:
            The spoken form, padded with a single space on each side.
        """
        expr = match.group(1)

        # Remove sizing delimiters
        expr = self._re_left_right.sub('', expr)

        # Fractions and square roots, innermost first; repeat until stable so
        # that nested forms such as \frac{\sqrt{2}}{2} resolve completely.
        prev = None
        while prev != expr:
            prev = expr
            expr = self._re_sqrt.sub(r' square root of \1 ', expr)
            expr = self._re_frac.sub(r' \1 over \2 ', expr)

        # Remaining commands (operators, Greek letters, brace-less \sqrt)
        for symbol, replacement in self.math_symbols.items():
            expr = re.sub(symbol, replacement, expr)

        # Handle exponents
        expr = self._re_pow_braced.sub(r' to the power of \1', expr)
        expr = self._re_pow_simple.sub(r' to the power of \1', expr)

        # Handle subscripts
        expr = self._re_sub_braced.sub(r' sub \1', expr)
        expr = self._re_sub_simple.sub(r' sub \1', expr)

        # Handle minus signs: signs first, then subtraction
        expr = self._re_unary_minus.sub(self._unary_minus_repl, expr)
        expr = self._re_binary_minus.sub(f' {self.binary_minus_word} ', expr)

        # Separate digits from adjacent letters ("2x" -> "2 x")
        expr = self._re_num_letter.sub(' ', expr)
        expr = self._re_letter_num.sub(' ', expr)

        # Clean up spacing
        expr = self._re_multi_space.sub(' ', expr).strip()
        return f" {expr} "

In [9]:
def sanitize_text(text: str) -> str:
    """Normalize arbitrary text into a plain form suitable for TTS input.

    Processing order:
    1. Unicode NFKC normalization.
    2. Curly quotes, en/em dashes, the Unicode minus sign and assorted
       Unicode spaces/line separators are mapped to ASCII equivalents.
    3. Runs of CR/LF/TAB become a single space.
    4. ``/ = + * % & |`` are spelled out as words; backslashes become spaces.
    5. All whitespace runs collapse to one space and the ends are trimmed.

    Empty or ``None`` input is returned unchanged.
    """
    if not text:
        return text

    # Code point -> ASCII replacement for typographic characters
    ascii_equivalents = {
        0x201C: '"', 0x201D: '"', 0x2018: "'", 0x2019: "'",
        0x2013: '-', 0x2014: '-', 0x2212: '-',
        0x00A0: ' ', 0x202F: ' ', 0x2007: ' ', 0x2009: ' ',
        0x200A: ' ', 0x2000: ' ', 0x2001: ' ', 0x2002: ' ',
        0x2003: ' ', 0x2004: ' ', 0x2005: ' ', 0x2006: ' ',
        0x2028: ' ', 0x2029: ' '
    }
    # Symbols that are spelled out as words
    spoken_symbols = str.maketrans({
        '/': ' slash ', '=': ' equals ', '+': ' plus ',
        '*': ' times ', '%': ' percent ', '&': ' and ',
        '|': ' or ', '\\': ' '
    })

    cleaned = unicodedata.normalize('NFKC', text).translate(ascii_equivalents)

    # Line breaks and tabs act as word separators
    cleaned = re.sub(r'[\r\n\t]+', ' ', cleaned)

    cleaned = cleaned.translate(spoken_symbols)

    # Single spaces only, no leading/trailing whitespace
    return re.sub(r'\s+', ' ', cleaned).strip()

In [10]:
@dataclass
class TTSChunk:
    """One unit of synthesized speech produced by the service.

    A chunk bundles the encoded audio with the text it was generated from and
    the per-character timing for that text, so a consumer can keep playback
    and text display in step.
    """
    # Base64 text of the PCM audio for this chunk
    audio_b64: str
    # Character timing lists (characters, start times, durations)
    alignment: Dict[str, List[float]]
    # The text that was synthesized to produce this chunk
    text: str

In [11]:
class WebSocketTTSService:
    """Streaming text-to-speech WebSocket server backed by a Kokoro pipeline.

    Clients send JSON messages of the form ``{"text": ..., "flush": ...}``
    (the message format is intended to follow the ElevenLabs streaming
    protocol). Incoming text is cleaned and buffered; audio is synthesized
    when the buffer holds a complete sentence, grows past a length limit, or
    the client asks for a flush. Each piece of audio is sent back as
    ``{"audio": <base64 PCM>, "alignment": {...}}``.

    Design points:
    1. Streaming: text is processed as it arrives.
    2. Sentence boundaries: audio is generated at sentence endings when possible.
    3. Error resilience: a failing message or chunk is logged and skipped.
    4. Non-blocking synthesis: the TTS call runs on a worker thread.
    """

    # Characters that mark the end of a speakable sentence in the buffer
    _SENTENCE_ENDINGS = ('.', '!', '?', '\n')
    # Buffer length above which audio is produced even without a sentence ending
    _MAX_BUFFER_CHARS = 100

    def __init__(self, host="localhost", port=8765, voice="af_heart", use_ngrok=False):
        """Create the service and load the Kokoro pipeline.

        Args:
            host: Address the server binds to.
            port: TCP port the server listens on.
            voice: Voice identifier passed to the pipeline on every call.
            use_ngrok: Stored on the instance; nothing in this class reads it.

        Raises:
            Exception: Whatever the pipeline constructor raises (logged first).
        """
        self.host = host
        self.port = port
        self.voice = voice
        self.use_ngrok = use_ngrok
        self.ngrok_tunnel = None
        self.math_processor = MathNotationProcessor()

        # Initialize Kokoro pipeline
        logger.info("Initializing Kokoro TTS pipeline...")
        try:
            self.pipeline = KPipeline(lang_code='a')
            self.default_sr = getattr(self.pipeline, "sample_rate", 24000)
            logger.info("TTS pipeline ready!")
        except Exception as e:
            logger.error(f"Failed to initialize Kokoro pipeline: {e}")
            raise

        # One long-lived worker thread for synthesis: it keeps the event loop
        # responsive while still running pipeline calls one at a time.
        self._tts_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="kokoro-tts"
        )

    async def handle_client(self, websocket):
        """Serve one client connection until it closes.

        Protocol as implemented here:
        1. A message whose text is a single space while the buffer is empty is
           treated as the opening handshake and ignored.
        2. Other text is cleaned and appended to the buffer; audio is produced
           when a sentence is complete or the buffer is long.
        3. ``flush=true`` synthesizes everything currently buffered.
        4. An empty text with ``flush=false`` ends the session.

        Errors in a single message are logged and the loop continues; a closed
        connection ends the handler.
        """
        peer = getattr(websocket, "remote_address", None)
        if isinstance(peer, (tuple, list)) and len(peer) >= 2:
            client_id = f"{peer[0]}:{peer[1]}"
        else:
            # No (host, port) pair available for this transport
            client_id = str(peer) if peer else "unknown"
        logger.info(f"Client connected: {client_id}")

        text_buffer = ""

        try:
            async for message in websocket:
                try:
                    data = json.loads(message)
                    text = data.get("text", "")
                    flush = data.get("flush", False)

                    logger.debug(f"Received: text='{text}', flush={flush}")

                    # Handle protocol: first message has single space, last has empty string
                    if text == " " and not text_buffer:
                        # First message - just acknowledge
                        continue
                    elif text == "" and not flush:
                        # Final message - close connection
                        logger.info(f"Client {client_id} requested disconnect")
                        break

                    # Add text to buffer (keeps word boundaries between chunks)
                    if text:
                        text_buffer = self._append_to_buffer(text_buffer, text)

                    # Check if we should generate audio
                    should_generate = flush or self._should_generate_audio(text_buffer)

                    if should_generate and text_buffer.strip():
                        if flush:
                            generate_text = text_buffer.strip()
                            text_buffer = ""
                        else:
                            generate_text, text_buffer = self._split_buffer(text_buffer)

                        if generate_text:
                            # Generate and stream audio
                            async for chunk in self._generate_audio_stream(generate_text):
                                if chunk.audio_b64:  # Only send if we have valid audio
                                    response = {
                                        "audio": chunk.audio_b64,
                                        "alignment": chunk.alignment
                                    }
                                    await websocket.send(json.dumps(response))

                except json.JSONDecodeError:
                    logger.error(f"Invalid JSON from client {client_id}")
                except websockets.exceptions.ConnectionClosed:
                    # Not a per-message problem: let the outer handler end the session
                    raise
                except Exception as e:
                    logger.error(f"Error processing message from {client_id}: {e}")

        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Client {client_id} disconnected")
        except Exception as e:
            logger.error(f"Error handling client {client_id}: {e}")
        finally:
            logger.info(f"Connection closed for client {client_id}")

    def _append_to_buffer(self, text_buffer: str, text: str) -> str:
        """Clean an incoming text chunk and append it to the buffer.

        ``sanitize_text`` trims both ends of every chunk, which would fuse
        words from adjacent chunks ("Hello " + "world" -> "Helloworld"). The
        chunk's own leading/trailing whitespace is therefore remembered and
        restored as a single separating space.
        """
        processed = sanitize_text(self.math_processor.process_math_text(text))
        starts_with_space = text[:1].isspace()
        ends_with_space = text[-1:].isspace()

        if not processed:
            # Whitespace-only chunk: acts purely as a word separator
            if text_buffer and not text_buffer.endswith(" ") and (starts_with_space or ends_with_space):
                return text_buffer + " "
            return text_buffer

        if text_buffer and starts_with_space and not text_buffer.endswith(" "):
            text_buffer += " "
        text_buffer += processed
        if ends_with_space:
            text_buffer += " "
        return text_buffer

    def _should_generate_audio(self, text_buffer: str) -> bool:
        """Decide whether the buffer is ready to be synthesized.

        Returns True when the buffer contains any sentence-ending character or
        is longer than ``_MAX_BUFFER_CHARS``; False for empty or
        whitespace-only buffers.
        """
        if not text_buffer.strip():
            return False

        # Generate on sentence boundaries
        if any(ending in text_buffer for ending in self._SENTENCE_ENDINGS):
            return True

        # Generate if buffer gets too long
        return len(text_buffer) > self._MAX_BUFFER_CHARS

    def _split_buffer(self, text_buffer: str) -> Tuple[str, str]:
        """Split the buffer into (text to synthesize now, text to keep).

        The cut is placed after the LAST sentence-ending character of any
        kind. Without a sentence ending, a buffer longer than
        ``_MAX_BUFFER_CHARS`` with more than five words is cut in half at a
        word boundary. Otherwise nothing is released.
        """
        cut = max(text_buffer.rfind(ending) for ending in self._SENTENCE_ENDINGS) + 1
        if cut > 0:
            # Keep trailing whitespace of the remainder: it separates the next chunk
            return text_buffer[:cut].strip(), text_buffer[cut:].lstrip()

        # If no sentence ending and buffer is long, split at word boundary
        # This prevents very long sentences from causing latency issues
        if len(text_buffer) > self._MAX_BUFFER_CHARS:
            words = text_buffer.split()
            if len(words) > 5:
                mid_point = len(words) // 2
                remainder = ' '.join(words[mid_point:])
                if text_buffer[-1:].isspace():
                    remainder += ' '
                return ' '.join(words[:mid_point]), remainder

        return "", text_buffer

    def _extract_complete_sentences(self, text_buffer: str) -> str:
        """Return the part of the buffer that should be synthesized now.

        This is everything up to and including the last sentence ending, or
        the first half of the words for a long buffer without any ending, or
        "" when the buffer should keep accumulating.
        """
        return self._split_buffer(text_buffer)[0]

    async def _generate_audio_stream(self, text: str) -> AsyncGenerator["TTSChunk", None]:
        """Synthesize ``text`` piece by piece, yielding each result when ready.

        The text is split with ``_split_text_smart`` (max_len=50). For every
        piece that yields audio, a ``TTSChunk`` is produced with the encoded
        audio and a heuristic alignment spanning the audio's duration. Pieces
        that fail or produce no audio are logged and skipped.
        """
        if not text.strip():
            return

        try:
            # Smaller pieces lower the latency of the first audio
            chunks = self._split_text_smart(text, max_len=50)

            for chunk_text in chunks:
                if not chunk_text.strip():
                    continue

                start_time = time.time()

                try:
                    audio, sr = await self._generate_kokoro_audio(chunk_text)

                    if audio is not None and audio.size > 0:
                        # Convert to required format
                        audio_b64 = audio_to_pcm16_base64(audio, sr)

                        # Duration at the source sample rate drives the alignment
                        audio_duration_ms = (len(audio) / sr) * 1000
                        alignment = alignment_for_text(chunk_text, audio_duration_ms)

                        generation_time = time.time() - start_time
                        logger.info(f"Generated audio for '{chunk_text[:30]}...' in {generation_time:.3f}s")

                        yield TTSChunk(
                            audio_b64=audio_b64,
                            alignment=alignment,
                            text=chunk_text
                        )
                    else:
                        logger.warning(f"No audio generated for text: '{chunk_text[:30]}...'")

                except Exception as e:
                    logger.error(f"Error generating audio for chunk '{chunk_text[:30]}...': {e}")
                    continue

        except Exception as e:
            logger.error(f"Error generating audio stream: {e}")

    def _synthesize_blocking(self, text: str) -> Tuple[np.ndarray, int]:
        """Run the pipeline and fully materialize its audio (blocking).

        Meant to be called on a worker thread. The pipeline's return value is
        consumed here, inside the worker, so that a lazily evaluated result
        (for example a generator) does its work off the event loop.

        Returns:
            (audio samples, sample rate). The rate is taken from the result
            when it is an ``(audio, rate)`` tuple or a dict with ``sr`` /
            ``sample_rate``; otherwise from the pipeline's ``sample_rate``
            attribute, falling back to ``default_sr`` / 24000.
        """
        result = self.pipeline(text, voice=self.voice)

        sr = int(getattr(self.pipeline, "sample_rate", getattr(self, "default_sr", 24000)))

        audio = _process_kokoro_output(result)

        # If the result carries sample rate info, prefer it
        if isinstance(result, tuple) and len(result) >= 2 and isinstance(result[1], (int, float)):
            sr = int(result[1])
        elif isinstance(result, dict) and 'sr' in result:
            sr = int(result['sr'])
        elif isinstance(result, dict) and 'sample_rate' in result:
            sr = int(result['sample_rate'])

        return audio, sr

    async def _generate_kokoro_audio(self, text: str) -> Tuple[np.ndarray, int]:
        """Synthesize ``text`` without blocking the event loop.

        The whole synthesis (pipeline call plus output conversion) runs on the
        service's worker thread.

        Returns:
            (audio samples, sample rate). On any error an empty float32 array
            and the default sample rate are returned and the error is logged.
        """
        try:
            loop = asyncio.get_running_loop()
            audio, sr = await loop.run_in_executor(
                getattr(self, "_tts_executor", None), self._synthesize_blocking, text
            )

            if audio.size == 0:
                logger.warning(f"Empty audio generated for text: '{text[:30]}...'")
                return np.array([], dtype=np.float32), sr

            logger.debug(f"Generated audio shape: {audio.shape}, sr: {sr}")
            return audio, sr

        except Exception as e:
            logger.error(f"Kokoro TTS error for text '{text[:30]}...': {e}")
            return np.array([], dtype=np.float32), int(getattr(self, "default_sr", 24000))

    def _split_text_smart(self, text: str, max_len: int = 50) -> List[str]:
        """Split text into pieces for synthesis.

        Text is first split after sentence punctuation (``. ! ?``). A sentence
        longer than ``max_len`` is split further after ``, ; :`` and the
        phrases are packed greedily into pieces of at most ``max_len``
        characters. A single phrase that is itself longer than ``max_len`` is
        emitted whole rather than cut mid-phrase.

        Returns:
            Non-empty, stripped pieces in their original order.
        """
        if not text:
            return []

        # Split on sentence boundaries first
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []

        for sentence in sentences:
            if len(sentence) <= max_len:
                chunks.append(sentence)
                continue

            # Split long sentences at phrase boundaries
            phrases = re.split(r'(?<=[,;:])\s+', sentence)
            current_chunk = ""

            for phrase in phrases:
                if len(current_chunk + " " + phrase) <= max_len:
                    current_chunk = (current_chunk + " " + phrase).strip()
                else:
                    if current_chunk:
                        chunks.append(current_chunk)
                    current_chunk = phrase

            if current_chunk:
                chunks.append(current_chunk)

        return [chunk.strip() for chunk in chunks if chunk.strip()]

    async def start_server(self):
        """Start the WebSocket server and serve until cancelled.

        Every connection is handled by ``handle_client``. The coroutine waits
        on a future that is never resolved, so it only ends when it is
        cancelled or interrupted.
        """
        logger.info(f"Starting TTS WebSocket server on {self.host}:{self.port}")

        # Create WebSocket server with connection handler
        async with websockets.serve(self.handle_client, self.host, self.port):
            logger.info(f"TTS WebSocket server running on ws://{self.host}:{self.port}")
            # Run forever (until interrupted)
            await asyncio.Future()

In [12]:
# WebSocket Client for Testing
class TTSWebSocketClient:
    """Minimal client for exercising the TTS WebSocket service.

    It speaks the same JSON protocol as the server's handler: text messages
    carrying ``text`` and ``flush`` go out, audio messages carrying ``audio``
    and ``alignment`` come back.
    """

    def __init__(self, uri="ws://localhost:8765"):
        """Remember the server URI; no connection is opened yet."""
        self.uri = uri
        self.websocket = None

    async def connect(self):
        """Open the WebSocket connection to ``self.uri``."""
        self.websocket = await websockets.connect(self.uri)
        logger.info(f"Connected to {self.uri}")

    async def send_text_chunk(self, text: str, flush: bool = False):
        """Send one text message to the service.

        Args:
            text: Text to synthesize
            flush: Whether to flush any buffered text
        """
        message = {"text": text, "flush": flush}
        await self.websocket.send(json.dumps(message))
        logger.debug(f"Sent: {message}")

    async def receive_audio_chunk(self) -> Optional[Dict]:
        """Wait for the next message from the service.

        Returns:
            The decoded JSON message, or None if the connection was closed.
        """
        try:
            message = await self.websocket.recv()
            return json.loads(message)
        except websockets.exceptions.ConnectionClosed:
            return None

    async def disconnect(self):
        """End the session and close the socket.

        Sends the end-of-session message (empty text, no flush) and closes the
        connection. Safe to call when never connected, and when the server has
        already closed the connection.
        """
        if self.websocket is None:
            return
        try:
            await self.send_text_chunk("", False)  # Send empty string to close
        except websockets.exceptions.ConnectionClosed:
            # The server already hung up; there is nobody left to notify
            pass
        finally:
            await self.websocket.close()
            self.websocket = None

In [13]:
# Main execution
async def main():
    """Entry point: run the test client or the server.

    When the first command-line argument is ``client`` the test client is
    run; in every other case a server is started on 0.0.0.0:8765.
    """
    import sys

    cli_args = sys.argv[1:]
    run_as_client = bool(cli_args) and cli_args[0] == "client"

    if run_as_client:
        # Run test client
        await run_test_client()
        return

    # Run server
    service = WebSocketTTSService(host="0.0.0.0", port=8765)
    await service.start_server()

In [14]:
async def run_test_client():
    """Run an end-to-end session against a server on ws://localhost:8765.

    Flow:
    1. Connect and send the initial single space.
    2. Send a few test sentences (some containing LaTeX math), one per second.
    3. Log every audio message that arrives, in a background task.
    4. Send the end-of-session message, wait until the server closes the
       connection, then close the local socket.
    """
    client = TTSWebSocketClient("ws://localhost:8765")
    await client.connect()

    # Test text with mathematical notation
    test_texts = [
        "Hello, this is a test of the TTS system. ",
        "The equation \\(a^2 + b^2 = c^2\\) is the Pythagorean theorem. ",
        "The derivative of \\(e^x\\) is \\(\\frac{d}{dx} e^x = e^x\\). ",
    ]

    async def receive_audio():
        """Log incoming audio messages until the connection closes."""
        while True:
            chunk = await client.receive_audio_chunk()
            if chunk is None:
                break
            chars = chunk['alignment']['chars']
            logger.info(f"Received audio chunk: {len(chunk['audio'])} base64 chars, "
                        f"{len(chars)} characters, "
                        f"Text: '{''.join(chars[0:20])}...'")

    # Start receiving audio in background task
    receive_task = asyncio.create_task(receive_audio())

    try:
        # Send initial space (protocol requirement)
        await client.send_text_chunk(" ")

        # Send text chunks with realistic typing delays
        for text in test_texts:
            await client.send_text_chunk(text)
            await asyncio.sleep(1)  # Simulate typing delay

        # Send final disconnect message; the server ends the session on it
        await client.send_text_chunk("", False)

        # Wait for all audio to be received (ends when the connection closes)
        await receive_task
    finally:
        # Do not leave the receiver running if sending failed part-way
        if not receive_task.done():
            receive_task.cancel()
        # The end-of-session message was already sent above, so only close the
        # socket here; sending again on the closed connection would raise.
        await client.websocket.close()

In [None]:
# Entry point: run main() to completion when this code is executed as __main__.
# NOTE(review): presumably asyncio.run() works inside the notebook's already
# running event loop only because nest_asyncio.apply() was called in the
# imports cell — confirm before moving this code elsewhere.
if __name__ == "__main__":
    asyncio.run(main())