In [1]:
# Install required packages when running interactively (already managed via pyproject.toml)
%pip install transformers datasets soundfile librosa --quiet

Note: you may need to restart the kernel to use updated packages.


In [2]:
import io
import logging
import zipfile
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, Optional

import librosa
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
)

# Configure root logging at INFO so progress messages show up in cell output.
logging.basicConfig(level=logging.INFO)
# Silence routine chatter from the transformers library; warnings still surface.
logging.getLogger("transformers").setLevel(logging.WARNING)
# Logger shared by the TTS engine and router code in this notebook.
logger = logging.getLogger("polylingua.tts")
logger.setLevel(logging.INFO)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
@dataclass
class TtsRequest:
    """Container describing a text-to-speech request.

    Values are normalized and validated right after construction.

    Attributes:
        text: Text to synthesize; surrounding whitespace is stripped.
        language: Language code; stripped and lower-cased, falling back to
            "en" when empty or None.
        voice_id: Optional voice name, stored as given.
        style: Optional style hint, stored as given.
        speed: Relative speaking rate (1.0 is the default); must be a finite
            number greater than zero.

    Raises:
        ValueError: If ``text`` is empty after stripping, or ``speed`` is not
            a finite number greater than zero.
    """
    text: str
    language: str = "en"
    voice_id: Optional[str] = None
    style: Optional[str] = None
    speed: float = 1.0

    def __post_init__(self) -> None:
        self.text = (self.text or "").strip()
        if not self.text:
            raise ValueError("TtsRequest.text must be a non-empty string")
        # A whitespace-only code is treated like a missing one so that it
        # cannot end up as a lookup key that never matches an engine.
        self.language = (self.language or "").strip().lower() or "en"
        # ``not speed > 0`` (rather than ``speed <= 0``) also rejects NaN,
        # for which every ordering comparison is False.
        if not self.speed > 0:
            raise ValueError("TtsRequest.speed must be greater than zero")
        if self.speed == float("inf"):
            raise ValueError("TtsRequest.speed must be a finite number")

@dataclass
class TtsResult:
    """Audio produced by a TTS engine, bundled with basic metadata."""
    # Encoded audio payload; save() writes it to disk unchanged.
    audio_bytes: bytes
    # Format label for audio_bytes (the SpeechT5 engine reports "wav").
    audio_format: str
    # Clip length in seconds.
    duration: float

    def save(self, path: str) -> None:
        """Write the audio bytes to ``path``, replacing any existing file."""
        with open(path, mode="wb") as sink:
            sink.write(self.audio_bytes)

In [4]:
class BaseTtsEngine(ABC):
    """Common contract that every text-to-speech backend implements."""

    # Short identifier for the backend; concrete engines override it.
    name: str = "base"

    @abstractmethod
    def synthesize(self, tts_request: TtsRequest) -> TtsResult:
        """Render ``tts_request`` as audio; concrete engines must override this."""
        raise NotImplementedError

    def supports_language(self, language: str) -> bool:
        """Report whether this engine can handle ``language``.

        The base implementation accepts every language; subclasses may
        override it to be more selective.
        """
        return True


class SpeechT5Engine(BaseTtsEngine):
    """Hugging Face SpeechT5 implementation supporting English voices.

    The model, vocoder and speaker embeddings are fetched lazily on the first
    ``synthesize`` call and memoized, so constructing an engine is cheap.
    """

    name = "speecht5"
    # Speaker ids that are kept when reading the x-vector archive.
    _AVAILABLE_VOICES = {"awb", "clb", "rms", "slt", "bdl"}

    def __init__(self, default_voice: str = "slt", device: Optional[str] = None):
        """Create an engine.

        Args:
            default_voice: Voice used when a request names no usable voice.
                Matched case-insensitively; unknown values fall back to "slt".
            device: Torch device string. When omitted, "cuda" is used if
                available, otherwise "cpu".
        """
        # Lower-case first so the check agrees with _pick_voice, which also
        # compares case-insensitively.
        normalized_voice = (default_voice or "").lower()
        self.default_voice = (
            normalized_voice if normalized_voice in self._AVAILABLE_VOICES else "slt"
        )
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Sample rate used for the WAV header and for the duration estimate.
        self.sample_rate = 16000

    @staticmethod
    @lru_cache(maxsize=4)
    def _load_models(device: str):
        """Load processor, acoustic model and vocoder, memoized per ``device``.

        Returns:
            A ``(processor, model, vocoder)`` tuple with the model and vocoder
            moved to ``device``.
        """
        processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
        return processor, model, vocoder

    @staticmethod
    @lru_cache(maxsize=1)
    def _load_speaker_embeddings() -> Dict[str, torch.Tensor]:
        """Download the x-vector archive and average the vectors per speaker.

        Returns:
            Mapping from speaker id to the mean embedding of that speaker,
            with a leading axis of size one added. The dict is memoized, so
            callers must treat it as read-only.

        Raises:
            RuntimeError: If the archive cannot be downloaded, or it contains
                no vectors for any of the known voices.
        """
        try:
            archive_path = hf_hub_download(
                repo_id="Matthijs/cmu-arctic-xvectors",
                filename="spkrec-xvect.zip",
                repo_type="dataset",
            )
        except Exception as exc:
            raise RuntimeError(
                "Unable to download SpeechT5 speaker embeddings archive from Hugging Face."
            ) from exc

        speaker_sums: Dict[str, torch.Tensor] = {}
        speaker_counts: Dict[str, int] = {}
        with zipfile.ZipFile(archive_path) as archive:
            for file_name in archive.namelist():
                if not file_name.endswith(".npy"):
                    continue
                # The speaker id is taken from the part of the file name
                # before the first "-", minus the "cmu_us_" / "_arctic" affixes.
                base_name = file_name.rsplit("/", 1)[-1]
                speaker_identifier = base_name.split("-")[0]
                speaker_id = (
                    speaker_identifier.replace("cmu_us_", "").replace("_arctic", "")
                )
                if speaker_id not in SpeechT5Engine._AVAILABLE_VOICES:
                    continue
                vector_np = np.load(io.BytesIO(archive.read(file_name)))
                vector = torch.from_numpy(vector_np).float()
                if speaker_id not in speaker_sums:
                    # Clone so the in-place accumulation below never touches
                    # memory shared with the freshly loaded array.
                    speaker_sums[speaker_id] = vector.clone()
                    speaker_counts[speaker_id] = 1
                else:
                    speaker_sums[speaker_id] += vector
                    speaker_counts[speaker_id] += 1

        # Every key in speaker_sums has a count of at least one, so the
        # division is always defined.
        embeddings = {
            speaker_id: (sum_vector / speaker_counts[speaker_id]).unsqueeze(0)
            for speaker_id, sum_vector in speaker_sums.items()
        }

        if not embeddings:
            raise RuntimeError("No speaker embeddings could be extracted from archive")
        return embeddings

    def _pick_voice(self, requested_voice: Optional[str]) -> str:
        """Return the lower-cased requested voice if known, else the default."""
        if requested_voice and requested_voice.lower() in self._AVAILABLE_VOICES:
            return requested_voice.lower()
        return self.default_voice

    def supports_language(self, language: str) -> bool:
        """Accept any language code that starts with "en" (case-insensitive)."""
        return language.lower().startswith("en")

    def synthesize(self, tts_request: TtsRequest) -> TtsResult:
        """Synthesize ``tts_request.text`` and return it as WAV bytes.

        The voice comes from ``tts_request.voice_id`` when it is known,
        otherwise from the engine default. A ``speed`` that differs from 1.0
        by more than 0.01 is applied by time-stretching the waveform; if that
        step fails, the unstretched audio is returned and a warning is logged.

        Raises:
            RuntimeError: If the speaker embeddings cannot be loaded.
        """
        voice_id = self._pick_voice(tts_request.voice_id)
        processor, model, vocoder = self._load_models(self.device)
        speaker_embeddings = self._load_speaker_embeddings()
        speaker_vector = speaker_embeddings.get(voice_id)
        if speaker_vector is None:
            # The archive may lack some voices. Prefer the configured default,
            # but if that one is missing too, use any voice that was loaded
            # (the mapping is never empty) instead of failing with a KeyError.
            fallback_voice = (
                self.default_voice
                if self.default_voice in speaker_embeddings
                else sorted(speaker_embeddings)[0]
            )
            logger.warning("Voice '%s' unavailable, defaulting to '%s'", voice_id, fallback_voice)
            speaker_vector = speaker_embeddings[fallback_voice]
        speaker_vector = speaker_vector.to(model.device)
        inputs = processor(text=tts_request.text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(model.device)
        speech = model.generate_speech(
            input_ids=input_ids,
            speaker_embeddings=speaker_vector,
            vocoder=vocoder,
        )
        audio = speech.cpu().numpy()
        if abs(tts_request.speed - 1.0) > 0.01:
            try:
                audio = librosa.effects.time_stretch(audio, rate=tts_request.speed)
            except Exception as exc:
                # Best effort: a failed speed change should not lose the audio.
                logger.warning("Time-stretch failed (%s); returning unmodified audio.", exc)
        buffer = io.BytesIO()
        sf.write(buffer, audio, self.sample_rate, format="WAV")
        # Computed after the optional stretch so it matches the written audio.
        duration = audio.shape[-1] / float(self.sample_rate)
        return TtsResult(audio_bytes=buffer.getvalue(), audio_format="wav", duration=duration)

In [5]:
class TtsRouterService:
    """Dispatches TTS requests to a registered engine, keyed by language code."""

    def __init__(
        self,
        engines: Optional[Dict[str, BaseTtsEngine]] = None,
        default_language: str = "en",
    ):
        """Set up the routing table.

        Args:
            engines: Mapping of language code to engine. When missing or
                empty, a single SpeechT5Engine is registered under the
                default language.
            default_language: Language whose engine serves as the last resort.
        """
        fallback_key = default_language.lower()
        if not engines:
            engines = {fallback_key: SpeechT5Engine()}
        # Keys are stored lower-cased so lookups are case-insensitive.
        self.engines = {code.lower(): backend for code, backend in engines.items()}
        self.default_language = fallback_key

    def register_engine(self, language: str, engine: BaseTtsEngine) -> None:
        """Add ``engine`` for ``language``, replacing any previous entry."""
        self.engines[language.lower()] = engine

    def list_languages(self) -> Dict[str, str]:
        """Return a mapping of registered language code to engine name."""
        return {code: backend.name for code, backend in self.engines.items()}

    def synthesize_with_best_engine(self, tts_request: TtsRequest) -> TtsResult:
        """Pick an engine for the request's language and run synthesis.

        Lookup order: an engine registered under the exact language code,
        then the first registered engine whose ``supports_language`` accepts
        the code, then the engine of the default language.

        Raises:
            RuntimeError: If none of the three lookups yields an engine.
        """
        language = (tts_request.language or self.default_language).lower()
        engine = self.engines.get(language)
        if engine is None:
            # Scan in registration order and stop at the first match.
            engine = next(
                (
                    backend
                    for backend in self.engines.values()
                    if backend.supports_language(language)
                ),
                None,
            )
        if engine is None:
            engine = self.engines.get(self.default_language)
        if engine is None:
            raise RuntimeError("No TTS engine registered for requested language")
        logger.info("Synthesizing with %s for language '%s'", engine.name, language)
        return engine.synthesize(tts_request)

In [6]:
# Smoke test: synthesize one short English sample and write it to disk.
# Running this cell downloads the pretrained weights on first use, so it can
# take a while; skip or comment it out when no audio sample is needed.
output_path = "sample_tts.wav"
router = TtsRouterService()
demo_request = TtsRequest(language="en", text="Hello from PolyLingua's TTS module.")
demo_result = router.synthesize_with_best_engine(demo_request)
demo_result.save(output_path)
print(f"Generated {output_path} ({demo_result.duration:.2f}s)")

INFO:polylingua.tts:Synthesizing with speecht5 for language 'en'


Generated sample_tts.wav (2.34s)
