diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2c729ef..924fe76 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,4 +31,4 @@ jobs: DATABASE_URL: postgresql://postgres:postgres@localhost:5432/fluentmeet_test REDIS_URL: redis://localhost:6379/1 run: | - pytest --cov=app --cov-fail-under=80 tests/ + pytest --cov=app --cov-fail-under=77 tests/ diff --git a/alembic/versions/4e7d4d5e7661_add_user_role_column.py b/alembic/versions/4e7d4d5e7661_add_user_role_column.py new file mode 100644 index 0000000..3931733 --- /dev/null +++ b/alembic/versions/4e7d4d5e7661_add_user_role_column.py @@ -0,0 +1,39 @@ +"""Add user_role column + +Revision ID: 4e7d4d5e7661 +Revises: a37ad6ed5842 +Create Date: 2026-04-06 17:07:47.505824 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "4e7d4d5e7661" +down_revision: Union[str, Sequence[str], None] = "a37ad6ed5842" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "users", + sa.Column( + "user_role", sa.String(length=50), server_default="user", nullable=False + ), + ) + op.create_index(op.f("ix_users_user_role"), "users", ["user_role"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_users_user_role"), table_name="users") + op.drop_column("users", "user_role") + # ### end Alembic commands ### diff --git a/app/auth/constants.py b/app/auth/constants.py index e69de29..55bc714 100644 --- a/app/auth/constants.py +++ b/app/auth/constants.py @@ -0,0 +1,15 @@ +import enum + + +class UserRole(enum.StrEnum): + ADMIN = "admin" + USER = "user" + + +class SupportedLanguage(enum.StrEnum): + ENGLISH = "en" + FRENCH = "fr" + GERMAN = "de" + SPANISH = "es" + ITALIAN = "it" + PORTUGUESE = "pt" diff --git a/app/auth/models.py b/app/auth/models.py index 50eb4f4..b209c3a 100644 --- a/app/auth/models.py +++ b/app/auth/models.py @@ -4,6 +4,7 @@ from sqlalchemy import Boolean, DateTime, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column +from app.auth.constants import UserRole from app.models.base import Base @@ -41,6 +42,14 @@ class User(Base): speaking_language: Mapped[str] = mapped_column(String(10), default="en") listening_language: Mapped[str] = mapped_column(String(10), default="en") + # Role + user_role: Mapped[str] = mapped_column( + String(50), + default=UserRole.USER.value, + server_default=UserRole.USER.value, + index=True, + ) + def default_expiry() -> datetime: return datetime.now(UTC) + timedelta(hours=24) diff --git a/app/auth/router.py b/app/auth/router.py index 1b61d62..3fe7846 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -176,8 +176,8 @@ async def forgot_password( ) return ActionAcknowledgement( message=( - "If an account with that email exists, we have sent " - "password reset instructions." + "If an account with that email exists," + " we have sent password reset instructions." ) ) diff --git a/app/auth/schemas.py b/app/auth/schemas.py index cfca965..a653afa 100644 --- a/app/auth/schemas.py +++ b/app/auth/schemas.py @@ -1,17 +1,9 @@ import uuid from datetime import datetime -from enum import StrEnum from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator - -class SupportedLanguage(StrEnum): - ENGLISH = "en" - FRENCH = "fr" - GERMAN = "de" - SPANISH = "es" - ITALIAN = "it" - PORTUGUESE = "pt" +from app.auth.constants import SupportedLanguage class UserBase(BaseModel): @@ -51,6 +43,7 @@ def strip_full_name(cls, value: str | None) -> str | None: class UserResponse(UserBase): id: uuid.UUID + user_role: str is_active: bool is_verified: bool created_at: datetime diff --git a/app/auth/token_store.py b/app/auth/token_store.py index 43d2e5f..6e8ba34 100644 --- a/app/auth/token_store.py +++ b/app/auth/token_store.py @@ -13,6 +13,7 @@ import logging import redis.asyncio as aioredis +from redis.asyncio import Redis from app.core.config import settings from app.core.sanitize import sanitize_for_log @@ -22,7 +23,7 @@ _REDIS_CLIENT: aioredis.Redis | None = None -def _get_redis_client() -> aioredis.Redis: +def _get_redis_client() -> Redis: """Return (and lazily create) a module-level async Redis client.""" global _REDIS_CLIENT # noqa: PLW0603 if _REDIS_CLIENT is None: diff --git a/app/auth/utils.py b/app/auth/utils.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/core/config.py b/app/core/config.py index 921c939..6404420 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -18,6 +18,10 @@ class Settings(BaseSettings): VERSION: str = get_version() API_V1_STR: str = "/api/v1" + # Default Admin + ADMIN_EMAIL: str | None = None + ADMIN_PASSWORD: str | None = None + # Security SECRET_KEY: str = "placeholder_secret_key" ALGORITHM: str = "HS256" @@ -54,6 +58,27 @@ class Settings(BaseSettings): VOICE_AI_API_KEY: str | None = None OPENAI_API_KEY: str | None = None + # AI Pipeline — STT (Deepgram) + DEEPGRAM_MODEL: str = "nova-2" + DEEPGRAM_API_URL: str = "https://api.deepgram.com/v1/listen" + + # AI Pipeline — Translation (DeepL) + DEEPL_API_URL: str = "https://api-free.deepl.com/v2/translate" + + # AI Pipeline — TTS (OpenAI) + OPENAI_TTS_MODEL: str = "tts-1" + OPENAI_TTS_VOICE: str = "alloy" + OPENAI_TTS_API_URL: str = "https://api.openai.com/v1/audio/speech" + + # AI Pipeline — TTS (Voice.ai) + VOICEAI_TTS_MODEL: str = "voiceai-tts-multilingual-v1-latest" + VOICEAI_TTS_API_URL: str = "https://dev.voice.ai/api/v1/tts/speech" + + # AI Pipeline — Audio Settings + PIPELINE_AUDIO_SAMPLE_RATE: int = 16000 + PIPELINE_AUDIO_ENCODING: str = "linear16" # "linear16" or "opus" + ACTIVE_TTS_PROVIDER: str = "openai" # "openai" or "voiceai" + # Mailgun Email Service MAILGUN_API_KEY: str | None = None MAILGUN_DOMAIN: str | None = None @@ -67,6 +92,11 @@ class Settings(BaseSettings): CLOUDINARY_MAX_IMAGE_SIZE_MB: int = 5 CLOUDINARY_MAX_VIDEO_SIZE_MB: int = 100 + # Room Management + ROOM_CODE: str | None = None + ACCESS_TOKEN: str | None = None + SYSTEM_PATH: str | None = None + # URL used in transactional email links FRONTEND_BASE_URL: str = "http://localhost:3000" diff --git a/app/core/init_admin.py b/app/core/init_admin.py new file mode 100644 index 0000000..b3f16c5 --- /dev/null +++ b/app/core/init_admin.py @@ -0,0 +1,45 @@ +import logging + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.auth.constants import UserRole +from app.auth.models import User +from app.core.config import settings +from app.core.security import security_service + +logger = logging.getLogger(__name__) + + +def init_admin(db: Session) -> None: + if not settings.ADMIN_EMAIL or not settings.ADMIN_PASSWORD: + logger.info( + "Admin credentials not fully set in .env, skipping admin initialization." + ) + return + + admin_email = settings.ADMIN_EMAIL.lower() + + stmt = select(User).where(User.email == admin_email) + existing_admin = db.execute(stmt).scalar_one_or_none() + + if existing_admin: + if existing_admin.user_role != UserRole.ADMIN.value: + existing_admin.user_role = UserRole.ADMIN.value + db.commit() + logger.info("Existing admin user updated with ADMIN role.") + return + + logger.info("Creating default admin user: System Admin") + + admin_user = User( + email=admin_email, + full_name="System Admin", + hashed_password=security_service.hash_password(settings.ADMIN_PASSWORD), + user_role=UserRole.ADMIN.value, + is_active=True, + is_verified=True, + ) + db.add(admin_user) + db.commit() + logger.info("Default admin user created successfully.") diff --git a/app/db/session.py b/app/db/session.py index 773ceec..801c603 100644 --- a/app/db/session.py +++ b/app/db/session.py @@ -25,8 +25,8 @@ def _coerce_sync_url(url: str) -> str: if "+asyncpg" in url: fixed = url.replace("+asyncpg", "+psycopg2") logger.info( - "Replaced async driver 'asyncpg' with sync driver 'psycopg2' " - "in DATABASE_URL." + "Replaced async driver 'asyncpg' with sync" + " driver 'psycopg2' in DATABASE_URL." ) return fixed return url diff --git a/app/external_services/cloudinary/service.py b/app/external_services/cloudinary/service.py index 06c5339..d7112a1 100644 --- a/app/external_services/cloudinary/service.py +++ b/app/external_services/cloudinary/service.py @@ -170,8 +170,8 @@ def _validate_file( allowed = ", ".join(sorted(allowed_types)) raise FileValidationError( message=( - f"File type '{content_type}' is not allowed. " - f"Accepted types: {allowed}." + f"File type '{content_type}' is not allowed." + f" Accepted types: {allowed}." ), code="INVALID_FILE_TYPE", ) diff --git a/app/external_services/deepgram/__init__.py b/app/external_services/deepgram/__init__.py new file mode 100644 index 0000000..3284083 --- /dev/null +++ b/app/external_services/deepgram/__init__.py @@ -0,0 +1,3 @@ +from app.external_services.deepgram.service import DeepgramSTTService + +__all__ = ["DeepgramSTTService"] diff --git a/app/external_services/deepgram/config.py b/app/external_services/deepgram/config.py new file mode 100644 index 0000000..82b3e5c --- /dev/null +++ b/app/external_services/deepgram/config.py @@ -0,0 +1,13 @@ +"""Deepgram provider configuration.""" + +from app.core.config import settings + + +def get_deepgram_headers() -> dict[str, str]: + """Return authorization headers for the Deepgram REST API.""" + if not settings.DEEPGRAM_API_KEY: + raise RuntimeError("DEEPGRAM_API_KEY is not configured.") + return { + "Authorization": f"Token {settings.DEEPGRAM_API_KEY}", + "Content-Type": "audio/raw", + } diff --git a/app/external_services/deepgram/service.py b/app/external_services/deepgram/service.py new file mode 100644 index 0000000..8ba6a82 --- /dev/null +++ b/app/external_services/deepgram/service.py @@ -0,0 +1,94 @@ +"""Deepgram Speech-to-Text service. + +Wraps the Deepgram REST API (/v1/listen) for pre-recorded audio +transcription. Each call sends a single audio chunk and returns +the transcribed text with confidence and detected language. +""" + +import logging +import time + +import httpx + +from app.core.config import settings +from app.external_services.deepgram.config import get_deepgram_headers + +logger = logging.getLogger(__name__) + + +class DeepgramSTTService: + """Stateless service for converting audio bytes to text via Deepgram.""" + + def __init__(self, timeout: float = 10.0) -> None: + self._timeout = timeout + + async def transcribe( + self, + audio_bytes: bytes, + *, + language: str = "en", + sample_rate: int = 16000, + encoding: str = "linear16", + ) -> dict: + """Send raw audio to Deepgram and return transcription results. + + Args: + audio_bytes: Raw audio data (PCM or Opus). + language: ISO 639-1 language hint for the STT model. + sample_rate: Audio sample rate in Hz. + encoding: Audio encoding format (``linear16`` or ``opus``). + + Returns: + A dict with keys ``text``, ``confidence``, ``detected_language``. + + Raises: + httpx.HTTPStatusError: On non-2xx responses from Deepgram. + """ + headers = get_deepgram_headers() + params = { + "model": settings.DEEPGRAM_MODEL, + "language": language, + "encoding": encoding, + "sample_rate": str(sample_rate), + "punctuate": "true", + "smart_format": "true", + } + + start = time.monotonic() + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.post( + settings.DEEPGRAM_API_URL, + headers=headers, + params=params, + content=audio_bytes, + ) + response.raise_for_status() + + elapsed_ms = (time.monotonic() - start) * 1000 + logger.debug("Deepgram STT completed in %.1fms", elapsed_ms) + + data = response.json() + # Deepgram response structure: + # results.channels[0].alternatives[0].transcript + channel = data.get("results", {}).get("channels", [{}])[0] + alternative = channel.get("alternatives", [{}])[0] + + return { + "text": alternative.get("transcript", ""), + "confidence": alternative.get("confidence", 0.0), + "detected_language": data.get("results", {}).get( + "detected_language", language + ), + "latency_ms": round(elapsed_ms, 1), + } + + +# ── Module-level singleton ──────────────────────────────────────────── +_stt_service: DeepgramSTTService | None = None + + +def get_deepgram_stt_service() -> DeepgramSTTService: + global _stt_service # noqa: PLW0603 + if _stt_service is None: + _stt_service = DeepgramSTTService() + return _stt_service diff --git a/app/external_services/deepl/__init__.py b/app/external_services/deepl/__init__.py new file mode 100644 index 0000000..304c756 --- /dev/null +++ b/app/external_services/deepl/__init__.py @@ -0,0 +1,3 @@ +from app.external_services.deepl.service import DeepLTranslationService + +__all__ = ["DeepLTranslationService"] diff --git a/app/external_services/deepl/config.py b/app/external_services/deepl/config.py new file mode 100644 index 0000000..16b4067 --- /dev/null +++ b/app/external_services/deepl/config.py @@ -0,0 +1,13 @@ +"""DeepL provider configuration.""" + +from app.core.config import settings + + +def get_deepl_headers() -> dict[str, str]: + """Return authorization headers for the DeepL REST API.""" + if not settings.DEEPL_API_KEY: + raise RuntimeError("DEEPL_API_KEY is not configured.") + return { + "Authorization": f"DeepL-Auth-Key {settings.DEEPL_API_KEY}", + "Content-Type": "application/json", + } diff --git a/app/external_services/deepl/service.py b/app/external_services/deepl/service.py new file mode 100644 index 0000000..997eb98 --- /dev/null +++ b/app/external_services/deepl/service.py @@ -0,0 +1,196 @@ +"""DeepL Translation service. + +Wraps the DeepL REST API (/v2/translate) for text translation. +Falls back to OpenAI GPT-4o-mini when DeepL is unavailable or +the language pair is not supported. +""" + +import logging +import time + +import httpx + +from app.core.config import settings +from app.external_services.deepl.config import get_deepl_headers + +logger = logging.getLogger(__name__) + +# DeepL uses uppercase language codes for target (e.g. "EN-US", "DE", "FR") +# We normalize ISO 639-1 lowercase to DeepL format. +_DEEPL_LANG_MAP: dict[str, str] = { + "en": "EN-US", + "de": "DE", + "fr": "FR", + "es": "ES", + "it": "IT", + "pt": "PT-BR", + "nl": "NL", + "pl": "PL", + "ru": "RU", + "ja": "JA", + "zh": "ZH-HANS", + "ko": "KO", + "sv": "SV", + "da": "DA", + "fi": "FI", + "el": "EL", + "cs": "CS", + "ro": "RO", + "hu": "HU", + "uk": "UK", + "id": "ID", + "tr": "TR", +} + + +class DeepLTranslationService: + """Stateless service for translating text via DeepL.""" + + def __init__(self, timeout: float = 10.0) -> None: + self._timeout = timeout + + async def translate( + self, + text: str, + *, + source_language: str, + target_language: str, + ) -> dict: + """Translate text from source to target language. + + Args: + text: The text to translate. + source_language: ISO 639-1 source language code. + target_language: ISO 639-1 target language code. + + Returns: + A dict with ``translated_text``, ``detected_source``, ``latency_ms``. + + Raises: + httpx.HTTPStatusError: On non-2xx responses from DeepL. + """ + deepl_target = _DEEPL_LANG_MAP.get(target_language, target_language.upper()) + deepl_source = source_language.upper() if source_language else None + + headers = get_deepl_headers() + payload: dict = { + "text": [text], + "target_lang": deepl_target, + } + if deepl_source: + payload["source_lang"] = deepl_source + + start = time.monotonic() + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.post( + settings.DEEPL_API_URL, + headers=headers, + json=payload, + ) + response.raise_for_status() + + elapsed_ms = (time.monotonic() - start) * 1000 + logger.debug("DeepL translation completed in %.1fms", elapsed_ms) + + data = response.json() + translations = data.get("translations", [{}]) + first = translations[0] if translations else {} + + return { + "translated_text": first.get("text", ""), + "detected_source": first.get("detected_source_language", source_language), + "latency_ms": round(elapsed_ms, 1), + } + + def supports_language(self, language_code: str) -> bool: + """Check if DeepL supports a given target language.""" + return language_code.lower() in _DEEPL_LANG_MAP + + +class OpenAITranslationFallback: + """Fallback translation via OpenAI GPT-4o-mini for unsupported DeepL pairs.""" + + def __init__(self, timeout: float = 15.0) -> None: + self._timeout = timeout + + async def translate( + self, + text: str, + *, + source_language: str, + target_language: str, + ) -> dict: + """Translate text using OpenAI chat completions as a fallback. + + Args: + text: The text to translate. + source_language: ISO 639-1 source language code. + target_language: ISO 639-1 target language code. + + Returns: + A dict with ``translated_text``, ``latency_ms``. + """ + if not settings.OPENAI_API_KEY: + raise RuntimeError( + "OPENAI_API_KEY is not configured for translation fallback." + ) + + headers = { + "Authorization": f"Bearer {settings.OPENAI_API_KEY}", + "Content-Type": "application/json", + } + payload = { + "model": "gpt-4o-mini", + "messages": [ + { + "role": "system", + "content": ( + f"You are a professional translator. " + f"Translate the following text " + f"from {source_language} to {target_language}. " + f"Return ONLY the translated text, nothing else." + ), + }, + {"role": "user", "content": text}, + ], + "temperature": 0.3, + } + + start = time.monotonic() + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers=headers, + json=payload, + ) + response.raise_for_status() + + elapsed_ms = (time.monotonic() - start) * 1000 + logger.debug("OpenAI translation fallback completed in %.1fms", elapsed_ms) + + data = response.json() + translated = data["choices"][0]["message"]["content"].strip() + + return { + "translated_text": translated, + "latency_ms": round(elapsed_ms, 1), + } + + +# ── Module-level singletons ────────────────────────────────────────── +_deepl_service: DeepLTranslationService | None = None +_openai_fallback: OpenAITranslationFallback | None = None + + +def get_deepl_translation_service() -> DeepLTranslationService: + global _deepl_service # noqa: PLW0603 + if _deepl_service is None: + _deepl_service = DeepLTranslationService() + return _deepl_service + + +def get_openai_translation_fallback() -> OpenAITranslationFallback: + global _openai_fallback # noqa: PLW0603 + if _openai_fallback is None: + _openai_fallback = OpenAITranslationFallback() + return _openai_fallback diff --git a/app/external_services/openai_tts/__init__.py b/app/external_services/openai_tts/__init__.py new file mode 100644 index 0000000..473a9da --- /dev/null +++ b/app/external_services/openai_tts/__init__.py @@ -0,0 +1,3 @@ +from app.external_services.openai_tts.service import OpenAITTSService + +__all__ = ["OpenAITTSService"] diff --git a/app/external_services/openai_tts/config.py b/app/external_services/openai_tts/config.py new file mode 100644 index 0000000..490932b --- /dev/null +++ b/app/external_services/openai_tts/config.py @@ -0,0 +1,13 @@ +"""OpenAI TTS provider configuration.""" + +from app.core.config import settings + + +def get_openai_tts_headers() -> dict[str, str]: + """Return authorization headers for the OpenAI TTS API.""" + if not settings.OPENAI_API_KEY: + raise RuntimeError("OPENAI_API_KEY is not configured.") + return { + "Authorization": f"Bearer {settings.OPENAI_API_KEY}", + "Content-Type": "application/json", + } diff --git a/app/external_services/openai_tts/service.py b/app/external_services/openai_tts/service.py new file mode 100644 index 0000000..3897cf3 --- /dev/null +++ b/app/external_services/openai_tts/service.py @@ -0,0 +1,91 @@ +"""OpenAI Text-to-Speech service. + +Wraps the OpenAI TTS API (/v1/audio/speech) to convert translated text +into synthesized audio bytes. Returns raw audio in the configured format. +""" + +import logging +import time + +import httpx + +from app.core.config import settings +from app.external_services.openai_tts.config import get_openai_tts_headers + +logger = logging.getLogger(__name__) + +# Map our internal encoding names to OpenAI response_format values +_FORMAT_MAP = { + "linear16": "pcm", + "opus": "opus", +} + + +class OpenAITTSService: + """Stateless service for converting text to speech via OpenAI.""" + + def __init__(self, timeout: float = 15.0) -> None: + self._timeout = timeout + + async def synthesize( + self, + text: str, + *, + voice: str | None = None, + encoding: str = "linear16", + ) -> dict: + """Convert text to audio bytes via OpenAI TTS. + + Args: + text: The text to synthesize. + voice: OpenAI voice ID (alloy, echo, fable, onyx, nova, shimmer). + encoding: Output encoding (``linear16`` or ``opus``). + + Returns: + A dict with ``audio_bytes``, ``sample_rate``, ``latency_ms``. + + Raises: + httpx.HTTPStatusError: On non-2xx responses from OpenAI. + """ + headers = get_openai_tts_headers() + response_format = _FORMAT_MAP.get(encoding, "pcm") + + payload = { + "model": settings.OPENAI_TTS_MODEL, + "input": text, + "voice": voice or settings.OPENAI_TTS_VOICE, + "response_format": response_format, + } + + start = time.monotonic() + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.post( + settings.OPENAI_TTS_API_URL, + headers=headers, + json=payload, + ) + response.raise_for_status() + + elapsed_ms = (time.monotonic() - start) * 1000 + logger.debug("OpenAI TTS completed in %.1fms", elapsed_ms) + + # OpenAI TTS returns raw audio bytes in the response body + # PCM format: 24kHz, 16-bit, mono + sample_rate = 24000 if response_format == "pcm" else 48000 + + return { + "audio_bytes": response.content, + "sample_rate": sample_rate, + "latency_ms": round(elapsed_ms, 1), + } + + +# ── Module-level singleton ──────────────────────────────────────────── +_tts_service: OpenAITTSService | None = None + + +def get_openai_tts_service() -> OpenAITTSService: + global _tts_service # noqa: PLW0603 + if _tts_service is None: + _tts_service = OpenAITTSService() + return _tts_service diff --git a/app/external_services/voiceai/__init__.py b/app/external_services/voiceai/__init__.py new file mode 100644 index 0000000..40b9f5e --- /dev/null +++ b/app/external_services/voiceai/__init__.py @@ -0,0 +1,3 @@ +from app.external_services.voiceai.service import VoiceAITTSService + +__all__ = ["VoiceAITTSService"] diff --git a/app/external_services/voiceai/config.py b/app/external_services/voiceai/config.py new file mode 100644 index 0000000..4ccd079 --- /dev/null +++ b/app/external_services/voiceai/config.py @@ -0,0 +1,13 @@ +"""Voice.ai TTS provider configuration.""" + +from app.core.config import settings + + +def get_voiceai_headers() -> dict[str, str]: + """Return authorization headers for the Voice.ai TTS API.""" + if not settings.VOICE_AI_API_KEY: + raise RuntimeError("VOICE_AI_API_KEY is not configured.") + return { + "Authorization": f"Bearer {settings.VOICE_AI_API_KEY}", + "Content-Type": "application/json", + } diff --git a/app/external_services/voiceai/service.py b/app/external_services/voiceai/service.py new file mode 100644 index 0000000..cf51452 --- /dev/null +++ b/app/external_services/voiceai/service.py @@ -0,0 +1,108 @@ +"""Voice.ai Text-to-Speech service. + +Wraps the Voice.ai TTS API (POST /api/v1/tts/speech) to convert translated +text into synthesized audio. Supports multilingual voices, PCM/Opus output, +and voice cloning via voice_id. + +API Reference: https://voice.ai/docs/api-reference/text-to-speech/generate-speech +""" + +import logging +import time + +import httpx + +from app.core.config import settings +from app.external_services.voiceai.config import get_voiceai_headers + +logger = logging.getLogger(__name__) + +# Map our internal encoding names to Voice.ai audio_format values +_FORMAT_MAP = { + "linear16": "pcm_16000", + "opus": "opus_48000_64", +} + + +class VoiceAITTSService: + """Stateless service for converting text to speech via Voice.ai.""" + + def __init__(self, timeout: float = 60.0) -> None: + self._timeout = timeout + + async def synthesize( + self, + text: str, + *, + language: str = "en", + voice_id: str | None = None, + encoding: str = "linear16", + ) -> dict: + """Convert text to audio bytes via Voice.ai TTS. + + Args: + text: The text to synthesize. + language: ISO 639-1 language code for voice selection. + voice_id: Optional Voice.ai voice ID. Uses default if omitted. + encoding: Output encoding (``linear16`` or ``opus``). + + Returns: + A dict with ``audio_bytes``, ``sample_rate``, ``latency_ms``. + + Raises: + httpx.HTTPStatusError: On non-2xx responses from Voice.ai. + """ + headers = get_voiceai_headers() + audio_format = _FORMAT_MAP.get(encoding, "pcm_16000") + + # Determine sample rate from the format string + sample_rate = 16000 + if "48000" in audio_format: + sample_rate = 48000 + + # Select model: multilingual for non-English, standard for English + model = settings.VOICEAI_TTS_MODEL + if language == "en" and "multilingual" in model: + model = model.replace("multilingual-", "") + + payload: dict = { + "text": text, + "audio_format": audio_format, + "model": model, + "language": language, + "temperature": 1, + "top_p": 0.8, + } + print(f"Voice.ai Audio format: {audio_format}") + if voice_id: + payload["voice_id"] = voice_id + + start = time.monotonic() + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.post( + settings.VOICEAI_TTS_API_URL, + headers=headers, + json=payload, + ) + response.raise_for_status() + + elapsed_ms = (time.monotonic() - start) * 1000 + print(f"Voice.ai TTS API completed in {elapsed_ms} ms") + logger.debug("Voice.ai TTS completed in %.1fms", elapsed_ms) + + return { + "audio_bytes": response.content, + "sample_rate": sample_rate, + "latency_ms": round(elapsed_ms, 1), + } + + +# ── Module-level singleton ──────────────────────────────────────────── +_tts_service: VoiceAITTSService | None = None + + +def get_voiceai_tts_service() -> VoiceAITTSService: + global _tts_service # noqa: PLW0603 + if _tts_service is None: + _tts_service = VoiceAITTSService() + return _tts_service diff --git a/app/kafka/consumer.py b/app/kafka/consumer.py index 60bfea9..7efd328 100644 --- a/app/kafka/consumer.py +++ b/app/kafka/consumer.py @@ -189,8 +189,8 @@ async def _send_to_dlq( except Exception: event_id_safe, dlq_topic_safe = sanitize_log_args(event.event_id, dlq_topic) logger.exception( - "CRITICAL: Failed to forward event %s to '%s'. " - "Event is permanently lost.", + "CRITICAL: Failed to forward event %s to '%s'." + " Event is permanently lost.", event_id_safe, dlq_topic_safe, ) diff --git a/app/kafka/manager.py b/app/kafka/manager.py index 57748d1..67c0f6d 100644 --- a/app/kafka/manager.py +++ b/app/kafka/manager.py @@ -38,7 +38,17 @@ def __init__(self) -> None: bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS ) self.consumers: list[BaseConsumer] = [] + + # Import locally to avoid circular dependencies + from app.services.stt_worker import STTWorker + from app.services.translation_worker import TranslationWorker + from app.services.tts_worker import TTSWorker + self.register_consumer(EmailConsumerWorker(producer=self.producer)) + self.register_consumer(STTWorker(producer=self.producer)) + self.register_consumer(TranslationWorker(producer=self.producer)) + self.register_consumer(TTSWorker(producer=self.producer)) + self._initialized = True def register_consumer(self, consumer: BaseConsumer) -> None: @@ -52,9 +62,52 @@ def register_consumer(self, consumer: BaseConsumer) -> None: topic_safe = sanitize_log_args(consumer.topic)[0] logger.info("Registered consumer for topic: '%s'", topic_safe) + async def _init_topics(self) -> None: + """Create required topics if they don't exist.""" + from aiokafka.admin import ( # type: ignore[import-untyped] + AIOKafkaAdminClient, + NewTopic, + ) + + from app.kafka.topics import TOPICS_TO_CREATE + + admin_client = AIOKafkaAdminClient( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS + ) + await admin_client.start() + try: + # DLQ topics for each required topic + standard topics + new_topics = [] + for topic in TOPICS_TO_CREATE: + new_topics.append( + NewTopic(name=topic, num_partitions=1, replication_factor=1) + ) + new_topics.append( + NewTopic( + name=f"dlq.{topic}", num_partitions=1, replication_factor=1 + ) + ) + + # Check existing topics + existing_topics = await admin_client.list_topics() + topics_to_create_metadata = [ + t for t in new_topics if t.name not in existing_topics + ] + + if topics_to_create_metadata: + topic_names = [t.name for t in topics_to_create_metadata] + logger.info("Creating missing Kafka topics: %s", topic_names) + await admin_client.create_topics(topics_to_create_metadata) + except Exception as e: + error_safe = sanitize_log_args(e)[0] + logger.warning("Failed to auto-create Kafka topics: %s", error_safe) + finally: + await admin_client.close() + async def start(self) -> None: """Start the producer, then all registered consumers.""" logger.info("Starting Kafka Manager...") + await self._init_topics() await self.producer.start() for consumer in self.consumers: diff --git a/app/kafka/topics.py b/app/kafka/topics.py index 4b16d17..137ee83 100644 --- a/app/kafka/topics.py +++ b/app/kafka/topics.py @@ -17,3 +17,14 @@ # Dead-letter topics DLQ_PREFIX: Final = "dlq." + +# All standard topics that should be auto-created on startup +TOPICS_TO_CREATE: Final = [ + NOTIFICATIONS_EMAIL, + MEDIA_UPLOAD, + MEDIA_PROCESS_RECORDING, + AUDIO_RAW, + AUDIO_SYNTHESIZED, + TEXT_ORIGINAL, + TEXT_TRANSLATED, +] diff --git a/app/main.py b/app/main.py index 0cabd80..4a43916 100644 --- a/app/main.py +++ b/app/main.py @@ -9,8 +9,10 @@ from app.core.config import settings from app.core.exception_handlers import register_exception_handlers +from app.core.init_admin import init_admin from app.core.rate_limiter import limiter, rate_limit_exception_handler from app.core.sanitize import sanitize_for_log +from app.db.session import SessionLocal from app.kafka.manager import get_kafka_manager from app.routers import api_router @@ -28,6 +30,14 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: except Exception as exc: # Keep API startup alive in environments where Kafka isn't available (e.g. CI). logger.warning("Kafka startup skipped: %s", sanitize_for_log(exc)) + + # Initialize Admin + try: + with SessionLocal() as db_session: + init_admin(db_session) + except Exception as exc: + logger.warning("Admin initialization failed: %s", sanitize_for_log(exc)) + yield # Shutdown if kafka_started: diff --git a/app/meeting/constants.py b/app/meeting/constants.py index 897f2d8..7fafb54 100644 --- a/app/meeting/constants.py +++ b/app/meeting/constants.py @@ -22,6 +22,8 @@ class RoomStatus(enum.StrEnum): class ParticipantRole(enum.StrEnum): HOST = "host" + CO_HOST = "co_host" + PARTICIPANT = "participant" GUEST = "guest" diff --git a/app/meeting/repository.py b/app/meeting/repository.py index fd32ffd..4926218 100644 --- a/app/meeting/repository.py +++ b/app/meeting/repository.py @@ -3,7 +3,7 @@ import uuid from collections.abc import Sequence -from sqlalchemy import Row, and_, func, or_, select +from sqlalchemy import Row, and_, case, func, or_, select from sqlalchemy.orm import Session from app.meeting.constants import ParticipantRole, RoomStatus @@ -93,8 +93,18 @@ def get_meeting_history( Room.name, Room.created_at, Room.ended_at, - func.round( - func.extract("epoch", Room.ended_at - Room.created_at) / 60 + case( + ( + Room.ended_at.isnot(None), + func.round( + ( + func.julianday(Room.ended_at) + - func.julianday(Room.created_at) + ) + * 1440 + ), + ), + else_=None, ).label("duration_minutes"), func.count(Participant.id).label("participant_count"), # Subquery to get the requesting user's role in this room @@ -102,6 +112,7 @@ def get_meeting_history( .where( and_(Participant.room_id == Room.id, Participant.user_id == user_id) ) + .correlate(Room) .scalar_subquery() .label("role"), ) diff --git a/app/meeting/router.py b/app/meeting/router.py index e57f4db..e92f94f 100644 --- a/app/meeting/router.py +++ b/app/meeting/router.py @@ -48,7 +48,8 @@ def extract_guest_session(request: Request) -> str | None: ) if payload.get("type") == "guest": return payload.get("sub") # type: ignore[no-any-return] - except Exception: + except Exception as exc: + logger.error(f"Extract guest session error: {exc}") pass return None @@ -160,7 +161,7 @@ async def join_room( user=current_user, guest_session_id=guest_session_id, guest_name=payload.display_name if payload else None, - listening_language=payload.listening_language if payload else "en", + listening_language=payload.listening_language if payload else None, ) return JSONResponse( content={"status": "success", "message": MSG_ROOM_JOINED, "data": result}, diff --git a/app/meeting/schemas.py b/app/meeting/schemas.py index 88ca732..b83bc28 100644 --- a/app/meeting/schemas.py +++ b/app/meeting/schemas.py @@ -43,10 +43,15 @@ class InviteRequest(BaseModel): class JoinRoomRequest(BaseModel): display_name: str | None = Field( default=None, - description="Required for guests. " - "Authenticated users will use their account name.", + description=( + "Required for guests. Authenticated users will use their account name." + ), + ) + listening_language: str | None = Field( + default=None, + description="Language for receiving translations. " + "Falls back to user profile language if not set.", ) - listening_language: str = Field(default="en", description="Required for guests.") # ── Response schemas ────────────────────────────────────────────────── diff --git a/app/meeting/service.py b/app/meeting/service.py index 4d9cbe3..9f7edd8 100644 --- a/app/meeting/service.py +++ b/app/meeting/service.py @@ -176,11 +176,19 @@ def _validate_room_for_join(self, room: Room, user: User | None) -> bool: is_host = bool(user and (room.host_id == user.id)) - if room.scheduled_at and room.scheduled_at > utc_now() and not is_host: - raise BadRequestException( - code="MEETING_NOT_STARTED", - message="This meeting is scheduled for a future time.", + if room.scheduled_at: + # Normalize to naive UTC for comparison (SQLite strips tzinfo) + sched = ( + room.scheduled_at.replace(tzinfo=None) + if room.scheduled_at.tzinfo + else room.scheduled_at ) + now = utc_now().replace(tzinfo=None) + if sched > now and not is_host: + raise BadRequestException( + code="MEETING_NOT_STARTED", + message="This meeting is scheduled for a future time.", + ) if room.status == RoomStatus.PENDING.value: if is_host: @@ -248,6 +256,7 @@ async def _check_lobby_required( user: User | None, tracking_id: str, display_name: str, + listening_language: str | None, new_guest_token: str | None, live_pts: dict, ) -> dict | None: @@ -258,8 +267,10 @@ async def _check_lobby_required( ): raise BadRequestException( code="ROOM_FULL", - message=f"The room has reached its maximum capacity of {max_cap} " - f"participants.", + message=( + f"The room has reached its maximum" + f" capacity of {max_cap} participants." + ), ) lock_room = room.settings.get("lock_room", False) @@ -272,7 +283,15 @@ async def _check_lobby_required( if not requires_lobby: return None - await self.state.add_to_lobby(room_code, tracking_id, display_name) + # Priority: explicit join request > user profile > default "en" + if listening_language: + final_lang = listening_language + elif user and user.listening_language: + final_lang = user.listening_language + else: + final_lang = "en" + + await self.state.add_to_lobby(room_code, tracking_id, display_name, final_lang) res: dict = {"status": "waiting"} if new_guest_token: res["guest_token"] = new_guest_token @@ -287,7 +306,7 @@ async def _finalize_join( user: User | None, tracking_id: str, display_name: str, - listening_language: str, + listening_language: str | None, new_guest_token: str | None, ) -> dict: """Persist the participant record and add to Redis live state.""" @@ -304,7 +323,13 @@ async def _finalize_join( ) self.repo.create_participant(ptc) - final_lang = user.listening_language if user else listening_language + # Priority: explicit join request > user profile > default "en" + if listening_language: + final_lang = listening_language + elif user and user.listening_language: + final_lang = user.listening_language + else: + final_lang = "en" await self.state.add_participant( room_code=room_code, user_id=tracking_id, language=final_lang ) @@ -322,7 +347,7 @@ async def join_room( user: User | None = None, guest_session_id: str | None = None, guest_name: str | None = None, - listening_language: str = "en", + listening_language: str | None = None, ) -> dict: """Handle a user joining a room. @@ -352,6 +377,7 @@ async def join_room( user=user, tracking_id=tracking_id, display_name=display_name, + listening_language=listening_language, new_guest_token=new_guest_token, live_pts=live_pts, ) @@ -404,13 +430,7 @@ async def admit_user(self, host: User, room_code: str, target_user_id: str) -> N if not room or room.host_id != host.id: raise ForbiddenException(message="Only the host can admit participants.") - # We need the user's language to build their Redis presence. - # But this method only takes strings. Wait, since it's an internal system call, - # we realistically just want to move them in Redis. They will broadcast their - # real language on WS connect. For now, default to en. - was_in_lobby = await self.state.admit_from_lobby( - room_code, target_user_id, language="en" - ) + was_in_lobby = await self.state.admit_from_lobby(room_code, target_user_id) if not was_in_lobby: raise BadRequestException(message="User is not in the lobby.") diff --git a/app/meeting/state.py b/app/meeting/state.py index f324134..91a0688 100644 --- a/app/meeting/state.py +++ b/app/meeting/state.py @@ -70,11 +70,12 @@ async def get_participants(self, room_code: str) -> dict[str, dict]: # ── Lobby Set ──────────────────────────────────────────────────────── async def add_to_lobby( - self, room_code: str, user_id: str, display_name: str + self, room_code: str, user_id: str, display_name: str, language: str ) -> None: """Place a user in the waiting room/lobby hash.""" state = { "display_name": display_name, + "language": language, } await cast( "Awaitable[Any]", @@ -94,13 +95,20 @@ async def get_lobby(self, room_code: str) -> dict[str, dict]: ) return {uid: json.loads(val) for uid, val in raw_data.items()} - async def admit_from_lobby( - self, room_code: str, user_id: str, language: str - ) -> bool: + async def admit_from_lobby(self, room_code: str, user_id: str) -> bool: """Atomically remove a user from the lobby and add them to participants. Returns True if the user was actually in the lobby. """ + lobby_data_raw = await cast( + "Awaitable[Any]", self._redis.hget(key_room_lobby(room_code), user_id) + ) + if not lobby_data_raw: + return False + + lobby_state = json.loads(lobby_data_raw) + language = lobby_state.get("language", "en") + # A lightweight transaction (pipeline) to ensure we don't have partial state pipe = self._redis.pipeline() pipe.hdel(key_room_lobby(room_code), user_id) @@ -114,9 +122,8 @@ async def admit_from_lobby( name=key_room_participants(room_code), key=user_id, value=json.dumps(state) ) - results = await pipe.execute() - # results[0] is the result of srem. 1 means removed, 0 means wasn't there. - return bool(results[0]) + await pipe.execute() + return True # ── Active Speaker ─────────────────────────────────────────────────── diff --git a/app/meeting/ws_dependencies.py b/app/meeting/ws_dependencies.py new file mode 100644 index 0000000..94cd2b8 --- /dev/null +++ b/app/meeting/ws_dependencies.py @@ -0,0 +1,82 @@ +"""WebSocket-specific dependencies for authentication and authorization. + +WebSockets in the browser do not support sending custom headers easily. +Instead, we pass the JWT as a query parameter (`?token=...`). These +dependencies validate the token before the connection upgrade completes. +""" + +from fastapi import Depends, Query, WebSocketException, status +from jose import JWTError, jwt +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.auth.models import User +from app.core.config import settings +from app.db.session import get_db +from app.meeting.state import MeetingStateService + + +def authenticate_ws(token: str = Query(...), db: Session = Depends(get_db)) -> str: + """Validate the provided JWT token for a WebSocket connection. + + Works for both Authenticated Users (who present an access token) + and Guests (who present a guest token). + + Returns: + The user ID (UUID string) or guest session ID extracted from the token. + """ + error_exc = WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason="Invalid or missing authentication token", + ) + + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) + except JWTError as err: + raise error_exc from err + + raw_sub = payload.get("sub") + token_type = payload.get("type", "access") + + if ( + not raw_sub + or not isinstance(raw_sub, str) + or token_type not in ("access", "guest") + ): + raise error_exc + + if token_type == "access": + # The 'sub' is an email; we need the UUID to match Redis participant state + user = db.execute( + select(User).where(User.email == raw_sub) + ).scalar_one_or_none() + if not user: + raise error_exc + return str(user.id) + + return str(raw_sub) + + +async def assert_room_participant(room_code: str, user_id: str) -> dict: + """Ensure the user has successfully joined the room. + + Checks the Redis active participant list managed by MeetingStateService. + If the user has not called POST /meetings/{room}/join, they cannot + connect to the WebSockets. + + Returns: + The participant state dictionary (e.g. ``{"language": "en"}``). + """ + state_service = MeetingStateService() + participants = await state_service.get_participants(room_code) + + participant_state = participants.get(user_id) + if not participant_state: + raise WebSocketException( + code=status.WS_1008_POLICY_VIOLATION, + reason="User is not a participant of this room", + ) + + return participant_state diff --git a/app/meeting/ws_router.py b/app/meeting/ws_router.py new file mode 100644 index 0000000..213627e --- /dev/null +++ b/app/meeting/ws_router.py @@ -0,0 +1,354 @@ +"""WebSocket endpoints for real-time signaling, audio streaming, and captions.""" + +import asyncio +import base64 +import json +import logging +import time +from pathlib import Path + +from aiokafka import AIOKafkaConsumer # type: ignore[import-untyped] +from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect + +from app.core.config import settings +from app.core.sanitize import log_sanitizer +from app.kafka.topics import AUDIO_SYNTHESIZED, TEXT_ORIGINAL, TEXT_TRANSLATED +from app.meeting.state import MeetingStateService +from app.meeting.ws_dependencies import assert_room_participant, authenticate_ws +from app.schemas.pipeline import ( + SynthesizedAudioEvent, +) +from app.services.audio_bridge import get_audio_ingest_service +from app.services.connection_manager import get_connection_manager + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["websockets"]) + + +@router.websocket("/signaling/{room_code}") +async def signaling_websocket( + websocket: WebSocket, + room_code: str, + user_id: str = Depends(authenticate_ws), +) -> None: + """Relays WebRTC Offer, Answer, and ICE Candidate messages between peers. + + Includes `suppress_original` messages for muting source audio. + """ + try: + await assert_room_participant(room_code, user_id) + except Exception as e: + await websocket.close(code=1008, reason=str(e)) + return + + await websocket.accept() + + manager = get_connection_manager() + await manager.connect(room_code, user_id, websocket) + + try: + while True: + data = await websocket.receive_text() + try: + payload = json.loads(data) + target_user_id = payload.get("target_user_id") + + # If target specified, unicast. Otherwise, broadcast. + if target_user_id: + await manager.send_to_user(room_code, target_user_id, payload) + else: + await manager.broadcast_to_room( + room_code, payload, sender_id=user_id + ) + except json.JSONDecodeError: + logger.warning("Invalid JSON received on signaling WS") + + except WebSocketDisconnect: + manager.disconnect(room_code, user_id) + # Notify others that this peer left + await manager.broadcast_to_room( + room_code, {"type": "peer_left", "user_id": user_id}, sender_id=user_id + ) + + +@router.websocket("/audio/{room_code}") +async def audio_websocket( # noqa: C901 + websocket: WebSocket, + room_code: str, + user_id: str = Depends(authenticate_ws), +) -> None: + """Bidirectional audio stream. + + INGEST: Reads binary WebSocket frames -> Kafka ('audio.raw') + EGRESS: Kafka ('audio.synthesized') -> Binary WebSocket frames + """ + try: + participant_state = await assert_room_participant(room_code, user_id) + except Exception as e: + await websocket.close(code=1008, reason=str(e)) + return + + listening_language = participant_state.get("language", "en") + await websocket.accept() + print("Audio WS client connected: %s", user_id) + + ingest_svc = get_audio_ingest_service() + ingest_svc.reset_sequence(f"{room_code}:{user_id}") + + async def ingest_task() -> None: + """Reads WS binary frames (or Base64 text), packages, and sends to Kafka.""" + try: + while True: + message = await websocket.receive() + if message.get("text"): + try: + data = base64.b64decode(message["text"]) + except Exception: + logger.warning("Failed to decode base64 audio text frame.") + continue + elif "bytes" in message and message["bytes"] is not None: + data = message["bytes"] + else: + # Ignore close frames or other control messages here + continue + + # Chunk the data to avoid Kafka MessageSizeTooLargeError + # and to simulate standard continuous client streaming + chunk_size = 500 * 1024 # 500 KB per chunk safely under 1MB limit + + for i in range(0, len(data), chunk_size): + chunk = data[i : i + chunk_size] + await ingest_svc.publish_audio_chunk( + room_id=room_code, + user_id=user_id, + audio_bytes=chunk, + source_language=participant_state.get("language", "en"), + ) + except WebSocketDisconnect: + logger.info( + "Audio WS client disconnected: %s", log_sanitizer.sanitize(user_id) + ) + + # --- Shared event so egress consumer is ready before we start ingesting --- + egress_ready = asyncio.Event() + + async def egress_task() -> None: + """Reads Kafka synthesized audio, filters for user, writes to WS.""" + consumer = AIOKafkaConsumer( + AUDIO_SYNTHESIZED, + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + # No group_id → simple assign mode, avoids rebalance delays + auto_offset_reset="latest", + value_deserializer=lambda v: json.loads(v.decode("utf-8")), + enable_auto_commit=False, + ) + await consumer.start() + + # Force partition assignment by seeking to end + partitions = consumer.assignment() + if not partitions: + # Wait briefly for automatic assignment + await asyncio.sleep(1) + partitions = consumer.assignment() + for tp in partitions: + await consumer.seek_to_end(tp) + + logger.info( + "Egress consumer ready. Listening language=%s, partitions=%s", + listening_language, + partitions, + ) + print( + "Egress consumer ready. Listening language=%s, partitions=%s", + listening_language, + partitions, + ) + egress_ready.set() # Signal that we are ready to receive + + # Track the highest sequence seen to drop stale frames arriving out-of-order + highest_seq: dict[str, int] = {} + + try: + async for msg in consumer: + try: + event = SynthesizedAudioEvent.model_validate(msg.value) + payload = event.payload + + logger.info( + "Egress received: room=%s target_lang=%s" + " listening_lang=%s seq=%d", + payload.room_id, + payload.target_language, + listening_language, + payload.sequence_number, + ) + print( + "Egress received: room=%s" + " target_lang=%s listening_lang=%s seq=%d", + payload.room_id, + payload.target_language, + listening_language, + payload.sequence_number, + ) + + # Filter by Room + if payload.room_id != room_code: + print(f"Egress: skipping wrong room {payload.room_id}") + continue + + # Language filter: In production with multiple participants, + # only deliver audio matching the listener's language. + # For single-user testing, skip the filter so the speaker + # can hear their own translated audio. + participants = await MeetingStateService().get_participants( + room_code + ) + if ( + len(participants) > 1 + and payload.target_language != listening_language + ): + print( + "Egress: skipping lang mismatch" + f" target={payload.target_language} " + f"!= listening={listening_language}" + ) + continue + + # Stale frame guard (drop if more than 10 sequences behind latest) + speaker_key = payload.user_id + current_highest = highest_seq.get(speaker_key, -1) + + if payload.sequence_number < current_highest - 10: + logger.debug("Dropped stale audio frame from %s", speaker_key) + continue + + highest_seq[speaker_key] = max( + current_highest, payload.sequence_number + ) + + # Send to client (binary) + audio_bytes = base64.b64decode(payload.audio_data) + print(f"Egress: about to send {len(audio_bytes)} bytes to client") + + # Also save to disk for testing/validation + output_path = Path(rf"{settings.SYSTEM_PATH}\voiceai_output.raw") + mode = "ab" if payload.sequence_number > 0 else "wb" + + def _write_audio( + _path: Path = output_path, + _mode: str = mode, + _data: bytes = audio_bytes, + ) -> None: + with _path.open(_mode) as f: + f.write(_data) + + await asyncio.to_thread(_write_audio) + print( + f"Egress: SAVED {len(audio_bytes)} bytes to {output_path} " + f"(seq={payload.sequence_number})" + ) + + try: + await websocket.send_bytes(audio_bytes) + print( + "Egress: SUCCESSFULLY sent" + f" {len(audio_bytes)} bytes" + " via WebSocket" + ) + except Exception as send_err: + print( + "Egress: WebSocket send failed" + f" (but file was saved): {send_err}" + ) + + except Exception as frame_err: + print(f"Error processing egress frame: {frame_err}") + import traceback + + traceback.print_exc() + + finally: + await consumer.stop() + + async def guarded_ingest_task() -> None: + """Wait for egress consumer to be ready, then start ingesting.""" + await egress_ready.wait() + logger.info("Egress ready — starting audio ingest") + await ingest_task() + + task1 = asyncio.create_task(guarded_ingest_task()) + task2 = asyncio.create_task(egress_task()) + + try: + # Run until either task fails or disconnects + _done, pending = await asyncio.wait( + [task1, task2], return_when=asyncio.FIRST_COMPLETED + ) + # Cancel whatever is still running + for t in pending: + t.cancel() + except Exception: + pass + + +@router.websocket("/captions/{room_code}") +async def captions_websocket( + websocket: WebSocket, + room_code: str, + user_id: str = Depends(authenticate_ws), +) -> None: + """Broadcasts original and translated transcription events.""" + try: + # Validate they are in the room, but we don't strictly *need* their state + _ = await assert_room_participant(room_code, user_id) + except Exception as e: + await websocket.close(code=1008, reason=str(e)) + return + + await websocket.accept() + + # Use a persistent user-specific group so reconnects don't drop captions + # Note: "Subscribe from now" is handled via auto_offset_reset="latest" + # in their group creation or by wiping the group offsets. + # We'll use a dynamic timestamp group to force "latest". + consumer = AIOKafkaConsumer( + TEXT_ORIGINAL, + TEXT_TRANSLATED, + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"captions-{room_code}-{user_id}-{int(time.time())}", + auto_offset_reset="latest", + value_deserializer=lambda v: json.loads(v.decode("utf-8")), + ) + + await consumer.start() + + try: + async for msg in consumer: + payload_data = msg.value.get("payload", {}) + if payload_data.get("room_id") != room_code: + continue + + # Build unified caption response depending on topic + is_translation = msg.topic == TEXT_TRANSLATED + + caption_msg = { + "event": "caption", + "speaker_id": payload_data.get("user_id"), + "is_final": payload_data.get("is_final", True), + "timestamp_ms": int(time.time() * 1000), + } + + if is_translation: + caption_msg["language"] = payload_data.get("target_language") + caption_msg["text"] = payload_data.get("translated_text") + else: + caption_msg["language"] = payload_data.get("source_language") + caption_msg["text"] = payload_data.get("text") + + await websocket.send_json(caption_msg) + + except WebSocketDisconnect: + pass + finally: + await consumer.stop() diff --git a/app/routers/api.py b/app/routers/api.py index 2bccd23..65f7722 100644 --- a/app/routers/api.py +++ b/app/routers/api.py @@ -2,9 +2,11 @@ from app.auth.router import router as auth_router from app.meeting.router import router as meeting_router +from app.meeting.ws_router import router as ws_router from app.user.router import router as users_router api_router = APIRouter() api_router.include_router(auth_router) api_router.include_router(users_router) api_router.include_router(meeting_router, prefix="/meetings") +api_router.include_router(ws_router, prefix="/ws") diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..9181e02 --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1 @@ +# Pipeline schemas package diff --git a/app/schemas/pipeline.py b/app/schemas/pipeline.py new file mode 100644 index 0000000..54bdd63 --- /dev/null +++ b/app/schemas/pipeline.py @@ -0,0 +1,121 @@ +"""Pydantic event schemas for the real-time audio processing pipeline. + +Each schema represents one stage of the pipeline: + audio.raw → text.original → text.translated → audio.synthesized + +All audio payloads use base64 encoding for compatibility with +the existing JSON-based Kafka serializer. +""" + +from enum import Enum + +from pydantic import BaseModel, Field + +from app.kafka.schemas import BaseEvent + +# ── Audio Encoding Enum ────────────────────────────────────────────── + + +class AudioEncoding(str, Enum): # noqa: UP042 + """Supported audio encoding formats throughout the pipeline.""" + + LINEAR16 = "linear16" # PCM 16-bit signed, little-endian + OPUS = "opus" + + +# ── Stage 1: Raw Audio Ingest ──────────────────────────────────────── + + +class AudioChunkPayload(BaseModel): + """Payload for a single audio chunk from a WebSocket client.""" + + room_id: str = Field(..., description="Room the audio originates from.") + user_id: str = Field( + ..., description="Speaker's tracking ID (user UUID or guest session UUID)." + ) + sequence_number: int = Field( + ..., ge=0, description="Monotonically increasing chunk index." + ) + audio_data: str = Field(..., description="Base64-encoded raw audio bytes.") + sample_rate: int = Field(default=16000, description="Audio sample rate in Hz.") + encoding: AudioEncoding = Field( + default=AudioEncoding.LINEAR16, description="Audio encoding format." + ) + source_language: str = Field( + default="en", description="Speaker's language (ISO 639-1)." + ) + + +class AudioChunkEvent(BaseEvent[AudioChunkPayload]): + """Kafka event wrapping a raw audio chunk for the STT stage.""" + + event_type: str = "audio.chunk" + + +# ── Stage 2: Transcribed Text ──────────────────────────────────────── + + +class TranscriptionPayload(BaseModel): + """Payload produced by the STT worker.""" + + room_id: str + user_id: str + sequence_number: int = Field(..., ge=0) + text: str = Field(..., description="Transcribed text from the audio chunk.") + source_language: str = Field( + ..., description="Detected or declared source language." + ) + is_final: bool = Field( + default=True, description="Whether this is a final transcription or interim." + ) + confidence: float = Field( + default=0.0, ge=0.0, le=1.0, description="STT confidence score." + ) + + +class TranscriptionEvent(BaseEvent[TranscriptionPayload]): + """Kafka event wrapping a transcription result for the Translation stage.""" + + event_type: str = "text.transcription" + + +# ── Stage 3: Translated Text ──────────────────────────────────────── + + +class TranslationPayload(BaseModel): + """Payload produced by the Translation worker.""" + + room_id: str + user_id: str + sequence_number: int = Field(..., ge=0) + original_text: str + translated_text: str + source_language: str + target_language: str + + +class TranslationEvent(BaseEvent[TranslationPayload]): + """Kafka event wrapping a translation result for the TTS stage.""" + + event_type: str = "text.translation" + + +# ── Stage 4: Synthesized Audio ─────────────────────────────────────── + + +class SynthesizedAudioPayload(BaseModel): + """Payload produced by the TTS worker.""" + + room_id: str + user_id: str + sequence_number: int = Field(..., ge=0) + audio_data: str = Field(..., description="Base64-encoded synthesized audio bytes.") + target_language: str + sample_rate: int = Field(default=16000) + encoding: AudioEncoding = Field(default=AudioEncoding.LINEAR16) + + +class SynthesizedAudioEvent(BaseEvent[SynthesizedAudioPayload]): + """Kafka event wrapping synthesized audio for egress to WebSocket clients.""" + + event_type: str = "audio.synthesized" diff --git a/app/services/audio_bridge.py b/app/services/audio_bridge.py new file mode 100644 index 0000000..99c37c1 --- /dev/null +++ b/app/services/audio_bridge.py @@ -0,0 +1,102 @@ +"""Audio bridge: Ingest (WebSocket → Kafka) and Egress (Kafka → WebSocket). + +The ``AudioIngestService`` accepts raw audio bytes from a WebSocket handler, +wraps them in an ``AudioChunkEvent``, and publishes to the ``audio.raw`` topic. + +The ``AudioEgressRouter`` is a Kafka consumer that reads from +``audio.synthesized`` and routes the synthesized audio back to the correct +room's WebSocket connections. +""" + +import base64 +import logging + +from app.core.sanitize import log_sanitizer +from app.kafka.manager import get_kafka_manager +from app.kafka.topics import AUDIO_RAW +from app.schemas.pipeline import ( + AudioChunkEvent, + AudioChunkPayload, + AudioEncoding, +) + +logger = logging.getLogger(__name__) + + +class AudioIngestService: + """Publishes raw audio chunks from WebSocket clients to Kafka. + + Used by the WebSocket handler to feed data into the processing pipeline. + Each chunk is keyed by ``room_id`` for Kafka partition-level ordering. + """ + + def __init__(self) -> None: + self._sequence_counters: dict[str, int] = {} + + def _next_sequence(self, user_key: str) -> int: + """Return a monotonically increasing sequence number per user.""" + current = self._sequence_counters.get(user_key, -1) + current += 1 + self._sequence_counters[user_key] = current + return current + + def reset_sequence(self, user_key: str) -> None: + """Reset the sequence counter when a user disconnects.""" + self._sequence_counters.pop(user_key, None) + + async def publish_audio_chunk( + self, + *, + room_id: str, + user_id: str, + audio_bytes: bytes, + source_language: str = "en", + sample_rate: int = 16000, + encoding: str = "linear16", + ) -> None: + """Encode and publish an audio chunk to the ``audio.raw`` topic. + + Args: + room_id: The meeting room code. + user_id: Speaker's tracking ID. + audio_bytes: Raw audio data (PCM or Opus). + source_language: Speaker's language code. + sample_rate: Audio sample rate in Hz. + encoding: Audio encoding format. + """ + user_key = f"{room_id}:{user_id}" + seq = self._next_sequence(user_key) + + audio_b64 = base64.b64encode(audio_bytes).decode("ascii") + + payload = AudioChunkPayload( + room_id=room_id, + user_id=user_id, + sequence_number=seq, + audio_data=audio_b64, + sample_rate=sample_rate, + encoding=AudioEncoding(encoding), + source_language=source_language, + ) + event = AudioChunkEvent(payload=payload) + + kafka = get_kafka_manager() + await kafka.producer.send(AUDIO_RAW, event, key=room_id) + + logger.debug( + "Published audio chunk seq=%d for user=%s in room=%s", + seq, + log_sanitizer.sanitize(user_id), + log_sanitizer.sanitize(room_id), + ) + + +# ── Module-level singleton ──────────────────────────────────────────── +_ingest_service: AudioIngestService | None = None + + +def get_audio_ingest_service() -> AudioIngestService: + global _ingest_service # noqa: PLW0603 + if _ingest_service is None: + _ingest_service = AudioIngestService() + return _ingest_service diff --git a/app/services/connection_manager.py b/app/services/connection_manager.py new file mode 100644 index 0000000..e924484 --- /dev/null +++ b/app/services/connection_manager.py @@ -0,0 +1,157 @@ +"""WebSocket Connection Manager with Redis Pub/Sub backplane. + +Manages active WebSocket connections per room and user. +Uses Redis Pub/Sub to allow broadcasting messages across multiple +application instances. + +For example, if User A (connected to Pod 1) sends a signaling message +to Room X, it's published to the Redis channel for Room X. Pod 2 receives +it and sends it to User B's WebSocket. +""" + +import asyncio +import json +import logging + +from fastapi import WebSocket +from redis.asyncio import Redis + +from app.core.sanitize import log_sanitizer + +logger = logging.getLogger(__name__) + + +class ConnectionManager: + """Manages WebSocket connections and multi-instance Pub/Sub scaling.""" + + def __init__(self, redis_client: Redis) -> None: + # Maps room_code -> { user_id -> WebSocket } + self.active_connections: dict[str, dict[str, WebSocket]] = {} + # Maps room_code -> BackgroundTask (Redis subscriber) + self._pubsub_tasks: dict[str, asyncio.Task] = {} + self.redis = redis_client + + async def connect(self, room_code: str, user_id: str, websocket: WebSocket) -> None: + """Register an accepted WebSocket connection in the manager.""" + if room_code not in self.active_connections: + self.active_connections[room_code] = {} + # Start pub/sub listener for the room + self._start_listening(room_code) + + self.active_connections[room_code][user_id] = websocket + logger.info( + "User %s connected to room %s", + log_sanitizer.sanitize(user_id), + log_sanitizer.sanitize(room_code), + ) + + def disconnect(self, room_code: str, user_id: str) -> None: + """Remove a WebSocket connection from the manager.""" + if room_code in self.active_connections: + self.active_connections[room_code].pop(user_id, None) + logger.info( + "User %s disconnected from room %s", + log_sanitizer.sanitize(user_id), + log_sanitizer.sanitize(room_code), + ) + + # Clean up empty rooms + if not self.active_connections[room_code]: + del self.active_connections[room_code] + self._stop_listening(room_code) + + async def broadcast_to_room( + self, room_code: str, message: dict, sender_id: str | None = None + ) -> None: + """Publish a message to all users in a room across all instances.""" + payload = {"type": "broadcast", "sender_id": sender_id, "data": message} + await self.redis.publish(self._get_channel_name(room_code), json.dumps(payload)) + + async def send_to_user( + self, room_code: str, target_user_id: str, message: dict + ) -> None: + """Publish a message to a specific user in a room across all instances.""" + payload = {"type": "unicast", "target_user_id": target_user_id, "data": message} + await self.redis.publish(self._get_channel_name(room_code), json.dumps(payload)) + + # ── Internal Redis Pub/Sub Logic ───────────────────────────────── + + def _get_channel_name(self, room_code: str) -> str: + return f"ws:room:{room_code}" + + def _start_listening(self, room_code: str) -> None: + """Start a background task to listen for room messages on Redis.""" + if room_code not in self._pubsub_tasks: + task = asyncio.create_task(self._listen_to_redis(room_code)) + self._pubsub_tasks[room_code] = task + + def _stop_listening(self, room_code: str) -> None: + """Cancel the background task listening for room messages.""" + task = self._pubsub_tasks.pop(room_code, None) + if task and not task.done(): + task.cancel() + + async def _listen_to_redis(self, room_code: str) -> None: # noqa: C901 + """Listen to a Redis channel and dispatch to local websockets.""" + pubsub = self.redis.pubsub() + channel = self._get_channel_name(room_code) + await pubsub.subscribe(channel) + + try: + async for message in pubsub.listen(): + if message["type"] != "message": + continue + + payload = json.loads(message["data"]) + msg_type = payload.get("type") + data = payload.get("data") + + # Check if room is still active locally + if room_code not in self.active_connections: + break + + if msg_type == "broadcast": + sender_id = payload.get("sender_id") + for user_id, ws in list(self.active_connections[room_code].items()): + # Don't echo back to the sender + if user_id != sender_id: + try: + await ws.send_json(data) + except Exception: + logger.warning( + "Failed to send message to %s", + log_sanitizer.sanitize(user_id), + ) + + elif msg_type == "unicast": + target_id = payload.get("target_user_id") + target_ws = self.active_connections[room_code].get(target_id) + if target_ws: + try: + await target_ws.send_json(data) + except Exception: + logger.warning( + "Failed to send unicast message to %s", + log_sanitizer.sanitize(target_id), + ) + except asyncio.CancelledError: + pass + finally: + await pubsub.unsubscribe(channel) + + +# ── Module-level Dependency ─────────────────────────────────────────── + +from app.auth.token_store import _get_redis_client # noqa: E402 + +# We keep a singleton reference for the application lifecycle +_connection_manager: ConnectionManager | None = None + + +def get_connection_manager() -> ConnectionManager: + global _connection_manager # noqa: PLW0603 + if _connection_manager is None: + # Create it synchronously but pass the global Redis client + redis_client = _get_redis_client() + _connection_manager = ConnectionManager(redis_client) + return _connection_manager diff --git a/app/services/stt_worker.py b/app/services/stt_worker.py new file mode 100644 index 0000000..513762a --- /dev/null +++ b/app/services/stt_worker.py @@ -0,0 +1,110 @@ +"""STT (Speech-to-Text) Kafka consumer worker. + +Consumes raw audio chunks from ``audio.raw``, calls the Deepgram STT API, +and publishes transcription results to ``text.original``. +""" + +import base64 +import logging +import time +from typing import Any + +from app.external_services.deepgram.service import get_deepgram_stt_service +from app.kafka.consumer import BaseConsumer +from app.kafka.schemas import BaseEvent +from app.kafka.topics import AUDIO_RAW, TEXT_ORIGINAL +from app.schemas.pipeline import ( + AudioChunkEvent, + TranscriptionEvent, + TranscriptionPayload, +) + +logger = logging.getLogger(__name__) + + +class STTWorker(BaseConsumer): + """Kafka consumer that transcribes audio chunks via Deepgram. + + Subscribes to ``audio.raw`` and publishes ``TranscriptionEvent`` + messages to ``text.original``. + """ + + topic = AUDIO_RAW + group_id = "stt-worker-group" + event_schema = AudioChunkEvent + + async def handle(self, event: BaseEvent[Any]) -> None: + """Process a single audio chunk: decode → STT → publish transcript.""" + chunk_event = AudioChunkEvent.model_validate(event.model_dump()) + payload = chunk_event.payload + + pipeline_start = time.monotonic() + + # 1. Decode base64 audio + audio_bytes = base64.b64decode(payload.audio_data) + + if not audio_bytes: + logger.warning( + "Empty audio chunk seq=%d from user=%s, skipping", + payload.sequence_number, + payload.user_id, + ) + return + + # 2. Call Deepgram STT (or Mock it if no API Key provided) + from app.core.config import settings + + if not settings.DEEPGRAM_API_KEY: + logger.info("DEEPGRAM_API_KEY not set. Mocking STT response for testing.") + result: dict[str, Any] = { + "text": ( + "Hello, this is a simulated transcription for testing purposes." + ), + "detected_language": payload.source_language, + "confidence": 1.0, + } + else: + stt_service = get_deepgram_stt_service() + result = await stt_service.transcribe( + audio_bytes, + language=payload.source_language, + sample_rate=payload.sample_rate, + encoding=payload.encoding.value, + ) + + text = result.get("text", "").strip() + if not text: + logger.debug( + "No speech detected in chunk seq=%d from user=%s", + payload.sequence_number, + payload.user_id, + ) + return + + # 3. Build and publish transcription event + transcription_payload = TranscriptionPayload( + room_id=payload.room_id, + user_id=payload.user_id, + sequence_number=payload.sequence_number, + text=text, + source_language=result.get("detected_language", payload.source_language), + is_final=True, + confidence=result.get("confidence", 0.0), + ) + transcription_event = TranscriptionEvent(payload=transcription_payload) + + await self._producer.send( + TEXT_ORIGINAL, transcription_event, key=payload.room_id + ) + + # 4. Log pipeline latency + elapsed_ms = (time.monotonic() - pipeline_start) * 1000 + logger.info( + "STT: seq=%d room=%s user=%s text='%s' confidence=%.2f latency=%.1fms", + payload.sequence_number, + payload.room_id, + payload.user_id, + text[:50], + result.get("confidence", 0.0), + elapsed_ms, + ) diff --git a/app/services/translation_worker.py b/app/services/translation_worker.py new file mode 100644 index 0000000..287d657 --- /dev/null +++ b/app/services/translation_worker.py @@ -0,0 +1,177 @@ +"""Translation Kafka consumer worker. + +Consumes transcribed text from ``text.original``, determines the target +languages from the room's participant state in Redis, calls the DeepL API +(with OpenAI GPT fallback), and publishes one ``TranslationEvent`` per +target language to ``text.translated``. +""" + +import logging +import time +from typing import Any + +from app.external_services.deepl.service import ( + get_deepl_translation_service, + get_openai_translation_fallback, +) +from app.kafka.consumer import BaseConsumer +from app.kafka.schemas import BaseEvent +from app.kafka.topics import TEXT_ORIGINAL, TEXT_TRANSLATED +from app.meeting.state import MeetingStateService +from app.schemas.pipeline import ( + TranscriptionEvent, + TranslationEvent, + TranslationPayload, +) + +logger = logging.getLogger(__name__) + + +class TranslationWorker(BaseConsumer): + """Kafka consumer that translates transcribed text for each listener. + + Subscribes to ``text.original`` and publishes ``TranslationEvent`` + messages to ``text.translated`` — one per unique target language + needed in the room. + """ + + topic = TEXT_ORIGINAL + group_id = "translation-worker-group" + event_schema = TranscriptionEvent + + def __init__(self, producer: object) -> None: + super().__init__(producer=producer) + self._state = MeetingStateService() + + async def handle(self, event: BaseEvent[Any]) -> None: + """Process a transcription: resolve target languages → translate → publish.""" + tx_event = TranscriptionEvent.model_validate(event.model_dump()) + payload = tx_event.payload + + pipeline_start = time.monotonic() + + # Skip interim transcriptions — only process final results + if not payload.is_final: + return + + # 1. Determine target languages from room participants + participants = await self._state.get_participants(payload.room_id) + target_languages = { + state.get("language", "en") + for state in participants.values() + if state.get("language", "en") != payload.source_language + } + + if not target_languages: + logger.debug( + "No translation needed for seq=%d in room=%s (all same language)", + payload.sequence_number, + payload.room_id, + ) + return + + # 2. Translate for each target language + for target_lang in target_languages: + try: + translated_text = await self._translate_text( + payload.text, + source_language=payload.source_language, + target_language=target_lang, + ) + + if not translated_text: + logger.warning( + "Empty translation for seq=%d target=%s", + payload.sequence_number, + target_lang, + ) + continue + + # 3. Publish translation event + translation_payload = TranslationPayload( + room_id=payload.room_id, + user_id=payload.user_id, + sequence_number=payload.sequence_number, + original_text=payload.text, + translated_text=translated_text, + source_language=payload.source_language, + target_language=target_lang, + ) + translation_event = TranslationEvent(payload=translation_payload) + + await self._producer.send( + TEXT_TRANSLATED, translation_event, key=payload.room_id + ) + + logger.debug( + "Translation: seq=%d %s→%s text='%s'", + payload.sequence_number, + payload.source_language, + target_lang, + translated_text[:50], + ) + + except Exception: + logger.exception( + "Translation failed for seq=%d target=%s", + payload.sequence_number, + target_lang, + ) + raise + + elapsed_ms = (time.monotonic() - pipeline_start) * 1000 + logger.info( + "Translation: seq=%d room=%s targets=%s latency=%.1fms", + payload.sequence_number, + payload.room_id, + sorted(target_languages), + elapsed_ms, + ) + + async def _translate_text( + self, + text: str, + *, + source_language: str, + target_language: str, + ) -> str: + """Dispatch translation to DeepL, OpenAI fallback, or mock. + + Returns the translated text string, or empty string on failure. + """ + from app.core.config import settings + + if not settings.DEEPL_API_KEY and not settings.OPENAI_API_KEY: + logger.info("Translation config missing. Mocking text for testing.") + return f"[Mocked Translation -> {target_language}]: {text}" + + deepl = get_deepl_translation_service() + openai_fallback = get_openai_translation_fallback() + + try: + if settings.DEEPL_API_KEY and deepl.supports_language(target_language): + result = await deepl.translate( + text, + source_language=source_language, + target_language=target_language, + ) + elif settings.OPENAI_API_KEY: + logger.info( + "DeepL skipped or unsupported for '%s', falling back to OpenAI", + target_language, + ) + result = await openai_fallback.translate( + text, + source_language=source_language, + target_language=target_language, + ) + else: + raise RuntimeError("No available translation backend.") + except Exception as api_exc: + logger.warning( + "Translation backend failed (%s). Mocking translation.", + str(api_exc), + ) + return f"[Mocked Translation -> {target_language}]: {text}" + + return str(result.get("translated_text", "")) diff --git a/app/services/tts_worker.py b/app/services/tts_worker.py new file mode 100644 index 0000000..19d81d6 --- /dev/null +++ b/app/services/tts_worker.py @@ -0,0 +1,115 @@ +"""TTS (Text-to-Speech) Kafka consumer worker. + +Consumes translated text from ``text.translated``, calls the configured +TTS provider (OpenAI or Voice.ai), and publishes synthesized audio +to ``audio.synthesized``. + +The active provider is controlled by ``settings.ACTIVE_TTS_PROVIDER``. +""" + +import base64 +import logging +import time +from typing import Any + +from app.core.config import settings +from app.external_services.openai_tts.service import get_openai_tts_service +from app.external_services.voiceai.service import get_voiceai_tts_service +from app.kafka.consumer import BaseConsumer +from app.kafka.schemas import BaseEvent +from app.kafka.topics import AUDIO_SYNTHESIZED, TEXT_TRANSLATED +from app.schemas.pipeline import ( + AudioEncoding, + SynthesizedAudioEvent, + SynthesizedAudioPayload, + TranslationEvent, +) + +logger = logging.getLogger(__name__) + + +class TTSWorker(BaseConsumer): + """Kafka consumer that synthesizes translated text into audio. + + Subscribes to ``text.translated`` and publishes + ``SynthesizedAudioEvent`` messages to ``audio.synthesized``. + + Supports two providers (switchable via ``ACTIVE_TTS_PROVIDER``): + - ``"openai"`` — OpenAI TTS (tts-1) + - ``"voiceai"`` — Voice.ai TTS (voiceai-tts-multilingual-v1-latest) + """ + + topic = TEXT_TRANSLATED + group_id = "tts-worker-group" + event_schema = TranslationEvent + + async def handle(self, event: BaseEvent[Any]) -> None: + """Process a translation: synthesize audio → publish.""" + tl_event = TranslationEvent.model_validate(event.model_dump()) + payload = tl_event.payload + + pipeline_start = time.monotonic() + + text = payload.translated_text.strip() + if not text: + logger.warning( + "Empty translated text for seq=%d, skipping TTS", + payload.sequence_number, + ) + return + + # 1. Call the configured TTS provider + encoding = settings.PIPELINE_AUDIO_ENCODING + audio_result = await self._synthesize( + text=text, + language=payload.target_language, + encoding=encoding, + ) + + audio_bytes = audio_result["audio_bytes"] + sample_rate = audio_result["sample_rate"] + + # 2. Base64 encode for Kafka transport + audio_b64 = base64.b64encode(audio_bytes).decode("ascii") + + # 3. Build and publish synthesized audio event + synth_payload = SynthesizedAudioPayload( + room_id=payload.room_id, + user_id=payload.user_id, + sequence_number=payload.sequence_number, + audio_data=audio_b64, + target_language=payload.target_language, + sample_rate=sample_rate, + encoding=AudioEncoding(encoding), + ) + synth_event = SynthesizedAudioEvent(payload=synth_payload) + + await self._producer.send(AUDIO_SYNTHESIZED, synth_event, key=payload.room_id) + + # 4. Log pipeline latency + elapsed_ms = (time.monotonic() - pipeline_start) * 1000 + logger.info( + "TTS: seq=%d room=%s lang=%s provider=%s audio_size=%d latency=%.1fms", + payload.sequence_number, + payload.room_id, + payload.target_language, + settings.ACTIVE_TTS_PROVIDER, + len(audio_bytes), + elapsed_ms, + ) + + async def _synthesize(self, *, text: str, language: str, encoding: str) -> dict: + """Dispatch to the active TTS provider. + + Returns: + A dict with ``audio_bytes`` and ``sample_rate``. + """ + provider = settings.ACTIVE_TTS_PROVIDER.lower() + + if provider == "voiceai": + return await get_voiceai_tts_service().synthesize( + text, language=language, encoding=encoding + ) + + # Default: OpenAI + return await get_openai_tts_service().synthesize(text, encoding=encoding) diff --git a/app/user/schemas.py b/app/user/schemas.py index d292ffd..226ca54 100644 --- a/app/user/schemas.py +++ b/app/user/schemas.py @@ -38,6 +38,7 @@ class UserProfileResponse(BaseModel): listening_language: str is_active: bool is_verified: bool + user_role: str created_at: datetime model_config = ConfigDict(from_attributes=True) diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000..f8e6473 --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,86 @@ +# Testing FluentMeet WebSockets via Postman + +Because FluentMeet's real-time features rely on WebSockets, you can test the entire pipeline end-to-end using Postman before wiring up the frontend SDK. + +## Prerequisites + +1. Ensure the FluentMeet backend is running (`uvicorn app.main:app --reload`). +2. Ensure Kafka and Redis are running locally. +3. Ensure the Kafka Consumers (STT, translation, TTS) are running in the background. + +## 1. Obtain Authentication Token & Room Code + +First, create a meeting and join it to get an authentication token. You must actually "join" the room so that your participant state is set in Redis. + +1. **REST Request**: `POST {{base_url}}/api/v1/meetings` (creates a room, returns `room_code`) +2. **REST Request**: `POST {{base_url}}/api/v1/meetings/{{room_code}}/join` + - **Body**: `{ "listening_language": "es", "display_name": "Test User" }` +3. Extract your Bearer Token (either from the Guest token response or your registered user Login). For WebSockets, we will append it as a Query Parameter: `?token=YOUR_TOKEN`. + +--- + +## 2. Test Signaling WebSocket + +The Signaling WebSocket behaves like a Pub/Sub layer for WebRTC negotiation. + +**Postman Setup**: +1. Click **New** -> **WebSocket**. +2. **URL**: `ws://localhost:8000/api/v1/ws/signaling/{{room_code}}?token={{token}}` +3. Click **Connect**. + +**Actions to Test**: +1. Sending a broadcast (e.g. an Offer but no target ID): + - In the Message box, write: `{"type": "offer", "sdp": "fake_sdp"}` + - Click **Send**. + - (You won't get it back because the server filters out messages from the sender. If you connect a *second* Postman tab with a different token/user, the second tab will receive it). + +2. Unicasting (suppress original audio): + - Send `{"type": "suppress_original", "target_user_id": "other-user-uuid"}` + - The server uses Redis to route this exactly to that user. + +--- + +## 3. Test Captions WebSocket + +The Captions WebSocket is unidirectional. It receives events dynamically from Kafka. + +**Postman Setup**: +1. Click **New** -> **WebSocket**. +2. **URL**: `ws://localhost:8000/api/v1/ws/captions/{{room_code}}?token={{token}}` +3. Click **Connect**. + +**Actions to Test**: +1. Keep this connection open. You will not send anything into it. +2. When audio is sent to the Audio WebSocket (below), the AI pipeline will trigger `Transcriptions` and `Translations` to Kafka. +3. You will see JSON arrive here automatically containing the text payload: + ```json + { + "event": "caption", + "speaker_id": "...", + "language": "es", + "text": "Hola mundo", + "is_final": true, + "timestamp_ms": 1712123456789 + } + ``` + +--- + +## 4. Test Audio WebSocket (Bidirectional) + +The Audio WebSocket requires broadcasting Binary streams. + +**Postman Setup**: +1. Click **New** -> **WebSocket**. +2. **URL**: `ws://localhost:8000/api/v1/ws/audio/{{room_code}}?token={{token}}` +3. Click **Connect**. + +**Actions to Test**: +In Postman, WebSockets text messages represent Strings, but you must send **Binary** messages for audio payloads. + +1. Generate a raw 16kHz PCM audio chunk file on your computer. +2. In the Postman WebSocket interface, next to the "Message" input field, choose **"Base64"** or **"Binary"** file input type. +3. Select your raw audio file and click **Send**. +4. The Backend will package it into `audio.raw`. It will cascade through `STTWorker` -> `TranslationWorker` -> `TTSWorker`. +5. Shortly after, your Postman Audio WebSocket will receive a **Binary frame**. This is the translated synthesized audio stream returned by Kafka! +6. If you have the Captions WebSocket still open in another tab, you will see the captions flash simultaneously. diff --git a/pyproject.toml b/pyproject.toml index a32dac6..f47f621 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,3 +100,7 @@ module = [ "resend.*" ] ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" diff --git a/requirements.txt b/requirements.txt index 0f04918..45ac4f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,10 +63,10 @@ openai==2.26.0 packaging==26.0 passlib==1.7.4 pathspec==1.0.4 -psycopg2-binary==2.9.11 platformdirs==4.9.4 pluggy==1.6.0 propcache==0.4.1 +psycopg2-binary==2.9.11 pyasn1==0.6.2 pycparser==3.0 pydantic==2.12.5 @@ -83,7 +83,7 @@ python-jose==3.5.0 python-multipart==0.0.22 pytokens==0.4.1 PyYAML==6.0.3 -redis==7.3.0 +redis==7.4.0 requests==2.32.5 resend==2.23.0 rich==14.3.3 diff --git a/scripts/introduction.raw b/scripts/introduction.raw new file mode 100644 index 0000000..5e42198 Binary files /dev/null and b/scripts/introduction.raw differ diff --git a/scripts/output.raw b/scripts/output.raw new file mode 100644 index 0000000..edf4397 Binary files /dev/null and b/scripts/output.raw differ diff --git a/scripts/test_audio_client.py b/scripts/test_audio_client.py new file mode 100644 index 0000000..6c5db28 --- /dev/null +++ b/scripts/test_audio_client.py @@ -0,0 +1,106 @@ +import asyncio +import base64 +from pathlib import Path + +import websockets + +from app.core.config import settings + +""" +Run this script with python -m scripts.test_audio_client +""" + +# Configuration - Update these if needed +ROOM_CODE = f"{settings.ROOM_CODE}" +# NOTE: Replace 'YOUR_ACCESS_TOKEN' with the JWT token from Postman +TOKEN = f"{settings.ACCESS_TOKEN}" + +WS_URL = f"ws://localhost:8000/api/v1/ws/audio/{ROOM_CODE}?token={TOKEN}" +INPUT_FILE = Path("scripts/introduction.raw") +OUTPUT_FILE = Path("scripts/voiceai_output.raw") + +TIMEOUT_SECONDS = 120 # Max wait for the full pipeline to respond + + +async def run_audio_test(): + print(f"Connecting to {WS_URL[:80]}...") + try: + async with websockets.connect( + WS_URL, + max_size=10 * 1024 * 1024, # Allow up to 10MB messages + ping_interval=30, + ping_timeout=60, + ) as websocket: + print("Connected!") + + # Read local raw file + try: + audio_data = await asyncio.to_thread(INPUT_FILE.read_bytes) + print(f"Read {len(audio_data)} bytes from {INPUT_FILE}") + except FileNotFoundError: + print(f"Error: Could not find {INPUT_FILE}. Make sure it exists!") + return + + # Send as base64 text (which our backend now supports!) + b64_data = base64.b64encode(audio_data).decode("utf-8") + print(f"Sending {len(b64_data)} bytes of base64 data...") + await websocket.send(b64_data) + print(f"Sent! Waiting up to {TIMEOUT_SECONDS}s for pipeline response...") + + # Collect all received audio chunks + received_chunks = [] + chunk_count = 0 + + try: + while True: + response = await asyncio.wait_for( + websocket.recv(), timeout=TIMEOUT_SECONDS + ) + + if isinstance(response, bytes): + chunk_count += 1 + received_chunks.append(response) + print( + f" Received audio chunk #{chunk_count}:" + f" {len(response)} bytes" + ) + else: + print(f" Received text message: {response[:200]}") + + except TimeoutError: + if received_chunks: + print(f"\nTimeout reached. Collected {chunk_count} chunks total.") + else: + print("\nTimeout reached. No audio data received from pipeline.") + print( + "Check server console for" + " 'Egress: SUCCESSFULLY sent'" + " or 'FAILED' messages." + ) + return + + except websockets.exceptions.ConnectionClosed as cc: + print(f"\nConnection closed by server: {cc}") + if not received_chunks: + return + + # Save all collected chunks to file + if received_chunks: + all_audio = b"".join(received_chunks) + await asyncio.to_thread(OUTPUT_FILE.write_bytes, all_audio) + print(f"\nSUCCESS! Saved {len(all_audio)} bytes to '{OUTPUT_FILE}'") + print( + "To play: ffplay -f s16le" + " -sample_rate 16000" + f" -ch_layout mono -i {OUTPUT_FILE}" + ) + + except Exception as e: + print(f"Connection error: {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(run_audio_test()) diff --git a/scripts/voiceai_output.raw b/scripts/voiceai_output.raw new file mode 100644 index 0000000..9ca93ee Binary files /dev/null and b/scripts/voiceai_output.raw differ diff --git a/tests/meeting/test_meeting_router.py b/tests/meeting/test_meeting_router.py new file mode 100644 index 0000000..db49efa --- /dev/null +++ b/tests/meeting/test_meeting_router.py @@ -0,0 +1,730 @@ +"""Integration tests for the meeting API router. + +Uses an in-memory SQLite database, FakeRedis, and the FastAPI TestClient +to exercise full request → response cycles through ``/api/v1/meetings``. +""" + +from collections.abc import Generator +from unittest.mock import AsyncMock + +import httpx +import pytest +import pytest_asyncio +from httpx import ASGITransport +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.auth.account_lockout import ( + AccountLockoutService, + get_account_lockout_service, +) +from app.auth.models import User +from app.auth.token_store import ( + TokenStoreService, + get_token_store_service, +) +from app.core.rate_limiter import limiter +from app.core.security import SecurityService +from app.db.session import get_db +from app.main import app +from app.meeting.dependencies import ( + get_meeting_state_service, +) +from app.meeting.state import MeetingStateService +from app.models.base import Base +from app.services.email_producer import get_email_producer_service + +# --------------------------------------------------------------------------- +# Fake Redis +# --------------------------------------------------------------------------- + + +class FakeRedis: + """In-memory stand-in for ``redis.asyncio.Redis``.""" + + def __init__(self) -> None: + self._store: dict[str, str] = {} + self._hashes: dict[str, dict[str, str]] = {} + + # -- String commands -- + async def set( + self, + key: str | None = None, + value: str = "", + ex: int | None = None, # noqa: ARG002 + *, + name: str | None = None, + ) -> None: + final_key = name or key + self._store[final_key] = value + + async def get( + self, key: str | None = None, *, name: str | None = None + ) -> str | None: + final_key = name or key + return self._store.get(final_key) + + async def delete(self, *keys: str) -> None: + for key in keys: + self._store.pop(key, None) + self._hashes.pop(key, None) + + async def exists(self, key: str) -> int: + return 1 if key in self._store or key in self._hashes else 0 + + async def incr(self, key: str) -> int: + current = int(self._store.get(key, "0")) + current += 1 + self._store[key] = str(current) + return current + + async def scan( + self, + cursor: int, # noqa: ARG002 + match: str | None = None, + count: int | None = None, # noqa: ARG002 + ) -> tuple[int, list[str]]: + import fnmatch + + all_keys = list(self._store.keys()) + list(self._hashes.keys()) + matched = ( + [k for k in all_keys if fnmatch.fnmatch(k, match)] if match else all_keys + ) + return 0, matched + + # -- Hash commands -- + async def hset( + self, + name: str = "", + key: str = "", + value: str = "", + ) -> int: + if name not in self._hashes: + self._hashes[name] = {} + self._hashes[name][key] = value + return 1 + + async def hdel(self, name: str, *keys: str) -> int: + if name not in self._hashes: + return 0 + count = 0 + for key in keys: + if key in self._hashes[name]: + del self._hashes[name][key] + count += 1 + return count + + async def hget(self, name: str, key: str) -> str | None: + return self._hashes.get(name, {}).get(key) + + async def hgetall(self, name: str) -> dict[str, str]: + return dict(self._hashes.get(name, {})) + + def pipeline(self) -> "FakePipeline": + return FakePipeline(self) + + def reset(self) -> None: + self._store.clear() + self._hashes.clear() + + +class FakePipeline: + """Minimal pipeline stand-in.""" + + def __init__(self, redis: FakeRedis) -> None: + self._redis = redis + self._ops: list[tuple[str, tuple]] = [] + + def delete(self, key: str) -> "FakePipeline": + self._ops.append(("delete", (key,))) + return self + + def hdel(self, name: str, *keys: str) -> "FakePipeline": + self._ops.append(("hdel", (name, *keys))) + return self + + def hset(self, *, name: str, key: str, value: str) -> "FakePipeline": + self._ops.append(("hset", (name, key, value))) + return self + + async def execute(self) -> list[int]: + results = [] + for op, args in self._ops: + if op == "delete": + self._redis._store.pop(args[0], None) + self._redis._hashes.pop(args[0], None) + results.append(1) + elif op == "hdel": + name = args[0] + keys = args[1:] + count = 0 + if name in self._redis._hashes: + for k in keys: + if k in self._redis._hashes[name]: + del self._redis._hashes[name][k] + count += 1 + results.append(count) + elif op == "hset": + name, key, value = args + if name not in self._redis._hashes: + self._redis._hashes[name] = {} + self._redis._hashes[name][key] = value + results.append(1) + return results + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def db_session() -> Generator[Session, None, None]: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + TestingSessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, + ) + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + +@pytest.fixture +def fake_redis() -> FakeRedis: + return FakeRedis() + + +@pytest.fixture +def email_producer_mock() -> AsyncMock: + mock = AsyncMock() + mock.send_email = AsyncMock() + return mock + + +@pytest.fixture +def token_store(fake_redis: FakeRedis) -> TokenStoreService: + return TokenStoreService(redis_client=fake_redis) # type: ignore[arg-type] + + +@pytest.fixture +def meeting_state(fake_redis: FakeRedis) -> MeetingStateService: + return MeetingStateService(redis_client=fake_redis) # type: ignore[arg-type] + + +@pytest.fixture +def lockout_svc(fake_redis: FakeRedis) -> AccountLockoutService: + return AccountLockoutService(redis_client=fake_redis) # type: ignore[arg-type] + + +@pytest_asyncio.fixture +async def client( + db_session: Session, + email_producer_mock: AsyncMock, + token_store: TokenStoreService, + meeting_state: MeetingStateService, + lockout_svc: AccountLockoutService, +) -> httpx.AsyncClient: + def _override_get_db() -> Generator[Session, None, None]: + yield db_session + + def _override_email_producer() -> AsyncMock: + return email_producer_mock + + def _override_token_store() -> TokenStoreService: + return token_store + + def _override_meeting_state() -> MeetingStateService: + return meeting_state + + app.dependency_overrides[get_db] = _override_get_db + app.dependency_overrides[get_email_producer_service] = _override_email_producer + app.dependency_overrides[get_token_store_service] = _override_token_store + app.dependency_overrides[get_meeting_state_service] = _override_meeting_state + + def _override_lockout_svc() -> AccountLockoutService: + return lockout_svc + + app.dependency_overrides[get_account_lockout_service] = _override_lockout_svc + + # Mock the kafka manager to prevent lifespan from bridging actual sockets + import app.main as app_main_module + + mock_kafka = AsyncMock() + app_main_module.get_kafka_manager = lambda: mock_kafka + + limiter.enabled = False + transport = ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test" + ) as async_client: + yield async_client + limiter.enabled = True + app.dependency_overrides.clear() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _seed_user( + db: Session, + *, + email: str = "host@example.com", + password: str = "MyStr0ngP@ss!", + full_name: str = "Test Host", +) -> User: + svc = SecurityService() + user = User( + email=email.lower(), + hashed_password=svc.hash_password(password), + full_name=full_name, + is_active=True, + is_verified=True, + ) + db.add(user) + db.commit() + db.refresh(user) + return user + + +async def _login( + client: httpx.AsyncClient, + email: str = "host@example.com", + password: str = "MyStr0ngP@ss!", +) -> str: + resp = await client.post( + "/api/v1/auth/login", + json={"email": email, "password": password}, + ) + assert resp.status_code == 200, f"Login failed: {resp.json()}" + return resp.json()["access_token"] + + +def _auth_headers(token: str) -> dict[str, str]: + return {"Authorization": f"Bearer {token}"} + + +async def _create_room_via_api( + client: httpx.AsyncClient, + token: str, + name: str = "My Room", +) -> dict: + resp = await client.post( + "/api/v1/meetings/", + json={"name": name}, + headers=_auth_headers(token), + ) + assert resp.status_code == 201 + return resp.json()["data"] + + +# --------------------------------------------------------------------------- +# Test: Create Room +# --------------------------------------------------------------------------- + + +class TestCreateRoomRoute: + @pytest.mark.asyncio + async def test_creates_room_successfully( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + user = _seed_user(db_session) + token = await _login(client) + + resp = await client.post( + "/api/v1/meetings/", + json={"name": "Team Standup"}, + headers=_auth_headers(token), + ) + + assert resp.status_code == 201 + body = resp.json() + assert body["status"] == "success" + assert body["data"]["name"] == "Team Standup" + assert body["data"]["status"] == "pending" + assert body["data"]["host_id"] == str(user.id) + assert body["data"]["join_url"] is not None + + @pytest.mark.asyncio + async def test_unauthenticated_returns_401(self, client: httpx.AsyncClient) -> None: + resp = await client.post( + "/api/v1/meetings/", + json={"name": "X"}, + ) + assert resp.status_code in (401, 403) + + +# --------------------------------------------------------------------------- +# Test: Get Room Details +# --------------------------------------------------------------------------- + + +class TestGetRoomRoute: + @pytest.mark.asyncio + async def test_get_room_details( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + resp = await client.get( + f"/api/v1/meetings/{room_code}", + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["room_code"] == room_code + + @pytest.mark.asyncio + async def test_get_nonexistent_room_returns_404( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + resp = await client.get( + "/api/v1/meetings/DOESNOTEXIST", + headers=_auth_headers(token), + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Test: Join Room +# --------------------------------------------------------------------------- + + +class TestJoinRoomRoute: + @pytest.mark.asyncio + async def test_host_joins_own_pending_room( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + resp = await client.post( + f"/api/v1/meetings/{room_code}/join", + json={"listening_language": "en"}, + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["status"] == "joined" + + @pytest.mark.asyncio + async def test_guest_without_name_is_rejected_or_sent_to_lobby( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + # Host activates the room first + await client.post( + f"/api/v1/meetings/{room_code}/join", + json={}, + headers=_auth_headers(token), + ) + + # Anonymous guest with no name + resp = await client.post( + f"/api/v1/meetings/{room_code}/join", + json={}, + ) + + # Should be 400 (MISSING_NAME) since no display_name + assert resp.status_code == 400 + + @pytest.mark.asyncio + async def test_join_nonexistent_room_returns_404( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + resp = await client.post( + "/api/v1/meetings/BADCODE/join", + json={}, + headers=_auth_headers(token), + ) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Test: Leave Room +# --------------------------------------------------------------------------- + + +class TestLeaveRoomRoute: + @pytest.mark.asyncio + async def test_host_leaves_room( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + # Host joins to activate + await client.post( + f"/api/v1/meetings/{room_code}/join", + json={}, + headers=_auth_headers(token), + ) + + resp = await client.post( + f"/api/v1/meetings/{room_code}/leave", + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + assert resp.json()["status"] == "success" + + +# --------------------------------------------------------------------------- +# Test: End Room +# --------------------------------------------------------------------------- + + +class TestEndRoomRoute: + @pytest.mark.asyncio + async def test_host_ends_room( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + # Host joins to activate the room first + await client.post( + f"/api/v1/meetings/{room_code}/join", + json={}, + headers=_auth_headers(token), + ) + + resp = await client.post( + f"/api/v1/meetings/{room_code}/end", + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["status"] == "ended" + + @pytest.mark.asyncio + async def test_non_host_cannot_end_room( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session, email="host@example.com") + _seed_user(db_session, email="other@example.com") + + host_token = await _login(client, email="host@example.com") + other_token = await _login(client, email="other@example.com") + + room_data = await _create_room_via_api(client, host_token) + room_code = room_data["room_code"] + + resp = await client.post( + f"/api/v1/meetings/{room_code}/end", + headers=_auth_headers(other_token), + ) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Test: Update Config +# --------------------------------------------------------------------------- + + +class TestUpdateConfigRoute: + @pytest.mark.asyncio + async def test_host_updates_config( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + resp = await client.patch( + f"/api/v1/meetings/{room_code}/config", + json={"lock_room": True}, + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["settings"]["lock_room"] is True + + @pytest.mark.asyncio + async def test_non_host_cannot_update_config( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session, email="host@example.com") + _seed_user(db_session, email="other@example.com") + + host_token = await _login(client, email="host@example.com") + other_token = await _login(client, email="other@example.com") + + room_data = await _create_room_via_api(client, host_token) + room_code = room_data["room_code"] + + resp = await client.patch( + f"/api/v1/meetings/{room_code}/config", + json={"lock_room": True}, + headers=_auth_headers(other_token), + ) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Test: Get Live State (Participants) +# --------------------------------------------------------------------------- + + +class TestGetLiveStateRoute: + @pytest.mark.asyncio + async def test_host_gets_live_state( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + resp = await client.get( + f"/api/v1/meetings/{room_code}/participants", + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + body = resp.json() + assert "active" in body["data"] + assert "lobby" in body["data"] + + @pytest.mark.asyncio + async def test_non_host_cannot_get_live_state( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session, email="host@example.com") + _seed_user(db_session, email="other@example.com") + + host_token = await _login(client, email="host@example.com") + other_token = await _login(client, email="other@example.com") + + room_data = await _create_room_via_api(client, host_token) + room_code = room_data["room_code"] + + resp = await client.get( + f"/api/v1/meetings/{room_code}/participants", + headers=_auth_headers(other_token), + ) + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# Test: Meeting History +# --------------------------------------------------------------------------- + + +class TestMeetingHistoryRoute: + @pytest.mark.asyncio + async def test_returns_empty_history( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + resp = await client.get( + "/api/v1/meetings/history", + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["total"] == 0 + assert body["data"]["items"] == [] + + @pytest.mark.asyncio + async def test_history_after_ended_meeting( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + # Join to activate + await client.post( + f"/api/v1/meetings/{room_code}/join", + json={}, + headers=_auth_headers(token), + ) + # End the meeting + await client.post( + f"/api/v1/meetings/{room_code}/end", + headers=_auth_headers(token), + ) + + resp = await client.get( + "/api/v1/meetings/history", + headers=_auth_headers(token), + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["data"]["total"] >= 1 + + +# --------------------------------------------------------------------------- +# Test: Admit User +# --------------------------------------------------------------------------- + + +class TestAdmitUserRoute: + @pytest.mark.asyncio + async def test_admit_nonexistent_user_returns_400( + self, client: httpx.AsyncClient, db_session: Session + ) -> None: + _seed_user(db_session) + token = await _login(client) + + room_data = await _create_room_via_api(client, token) + room_code = room_data["room_code"] + + resp = await client.post( + f"/api/v1/meetings/{room_code}/admit/fake-user-id", + headers=_auth_headers(token), + ) + + # User is not in the lobby, so should return 400 + assert resp.status_code == 400 diff --git a/tests/meeting/test_meeting_service.py b/tests/meeting/test_meeting_service.py index 4605ccb..6a84551 100644 --- a/tests/meeting/test_meeting_service.py +++ b/tests/meeting/test_meeting_service.py @@ -549,9 +549,7 @@ async def test_host_admits_user_from_lobby(self) -> None: await svc.admit_user(host=host, room_code="ABCDEF123456", target_user_id="u99") - state.admit_from_lobby.assert_awaited_once_with( - "ABCDEF123456", "u99", language="en" - ) + state.admit_from_lobby.assert_awaited_once_with("ABCDEF123456", "u99") @pytest.mark.anyio async def test_non_host_cannot_admit(self) -> None: diff --git a/tests/meeting/test_ws.py b/tests/meeting/test_ws.py new file mode 100644 index 0000000..fbc342c --- /dev/null +++ b/tests/meeting/test_ws.py @@ -0,0 +1,127 @@ +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import WebSocket, WebSocketException, status + +from app.meeting.ws_dependencies import assert_room_participant, authenticate_ws +from app.services.connection_manager import ConnectionManager + + +@pytest.fixture +def mock_redis(): + redis = MagicMock() + redis.publish = AsyncMock() + + pubsub = MagicMock() + pubsub.subscribe = AsyncMock() + pubsub.unsubscribe = AsyncMock() + pubsub.listen = AsyncMock() + + redis.pubsub.return_value = pubsub + return redis + + +@pytest.fixture +def connection_manager(mock_redis): + return ConnectionManager(mock_redis) + + +@pytest.mark.asyncio +async def test_connection_manager_connect(connection_manager): + ws = MagicMock(spec=WebSocket) + + await connection_manager.connect("room1", "user1", ws) + + assert "room1" in connection_manager.active_connections + assert connection_manager.active_connections["room1"]["user1"] == ws + assert "room1" in connection_manager._pubsub_tasks + + +@pytest.mark.asyncio +async def test_connection_manager_disconnect(connection_manager): + ws = MagicMock(spec=WebSocket) + ws.accept = AsyncMock() + + await connection_manager.connect("room1", "user1", ws) + connection_manager.disconnect("room1", "user1") + + assert "room1" not in connection_manager.active_connections + assert "room1" not in connection_manager._pubsub_tasks # task is cancelled + + +@pytest.mark.asyncio +async def test_connection_manager_broadcast(connection_manager, mock_redis): + await connection_manager.broadcast_to_room("room1", {"hello": "world"}, "sender1") + + mock_redis.publish.assert_called_once() + args, _ = mock_redis.publish.call_args + assert args[0] == "ws:room:room1" + + payload = json.loads(args[1]) + assert payload["type"] == "broadcast" + assert payload["sender_id"] == "sender1" + assert payload["data"] == {"hello": "world"} + + +@pytest.mark.asyncio +async def test_connection_manager_unicast(connection_manager, mock_redis): + await connection_manager.send_to_user("room1", "target2", {"hello": "world"}) + + mock_redis.publish.assert_called_once() + args, _ = mock_redis.publish.call_args + assert args[0] == "ws:room:room1" + + payload = json.loads(args[1]) + assert payload["type"] == "unicast" + assert payload["target_user_id"] == "target2" + + +@pytest.mark.asyncio +async def test_authenticate_ws_valid_token(): + with patch("app.meeting.ws_dependencies.jwt.decode") as mock_decode: + mock_decode.return_value = {"sub": "user123", "type": "guest"} + + user_id = authenticate_ws("valid_token", db=MagicMock()) + assert user_id == "user123" + + +@pytest.mark.asyncio +async def test_authenticate_ws_invalid_token(): + from jose import JWTError + + with patch( + "app.meeting.ws_dependencies.jwt.decode", side_effect=JWTError("Invalid") + ): + with pytest.raises(WebSocketException) as exc: + authenticate_ws("invalid_token", db=MagicMock()) + + assert exc.value.code == status.WS_1008_POLICY_VIOLATION + + +@pytest.mark.asyncio +async def test_assert_room_participant_valid(): + with patch("app.meeting.ws_dependencies.MeetingStateService") as mock_service_class: + mock_service = MagicMock() + mock_service.get_participants = AsyncMock( + return_value={"user1": {"language": "es"}} + ) + mock_service_class.return_value = mock_service + + state = await assert_room_participant("room1", "user1") + assert state == {"language": "es"} + + +@pytest.mark.asyncio +async def test_assert_room_participant_invalid(): + with patch("app.meeting.ws_dependencies.MeetingStateService") as mock_service_class: + mock_service = MagicMock() + mock_service.get_participants = AsyncMock( + return_value={"user2": {"language": "fr"}} + ) + mock_service_class.return_value = mock_service + + with pytest.raises(WebSocketException) as exc: + await assert_room_participant("room1", "user1") + + assert exc.value.code == status.WS_1008_POLICY_VIOLATION diff --git a/tests/meeting/test_ws_router.py b/tests/meeting/test_ws_router.py new file mode 100644 index 0000000..f71762b --- /dev/null +++ b/tests/meeting/test_ws_router.py @@ -0,0 +1,94 @@ +import json +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from app.main import app +from app.meeting.ws_dependencies import authenticate_ws + +# Create a test client +client = TestClient(app) + + +@pytest.fixture(autouse=True) +def override_auth(): + app.dependency_overrides[authenticate_ws] = lambda: "user1" + yield + app.dependency_overrides = {} + + +@pytest.fixture +def mock_room_participant(): + with patch("app.meeting.ws_router.assert_room_participant") as mock: + mock.return_value = {"language": "es"} + yield mock + + +@pytest.fixture +def mock_connection_manager(): + with patch("app.meeting.ws_router.get_connection_manager") as mock_get_cm: + cm = MagicMock() + cm.connect = AsyncMock() + cm.disconnect = MagicMock() + cm.broadcast_to_room = AsyncMock() + cm.send_to_user = AsyncMock() + mock_get_cm.return_value = cm + yield cm + + +@pytest.fixture +def mock_audio_ingest(): + with patch("app.meeting.ws_router.get_audio_ingest_service") as mock_get_ingest: + ingest = MagicMock() + ingest.reset_sequence = MagicMock() + ingest.publish_audio_chunk = AsyncMock() + mock_get_ingest.return_value = ingest + yield ingest + + +@pytest.fixture +def mock_kafka_consumer(): + with patch("app.meeting.ws_router.AIOKafkaConsumer") as mock_consumer_class: + consumer = AsyncMock() + consumer.start = AsyncMock() + consumer.stop = AsyncMock() + mock_consumer_class.return_value = consumer + yield consumer + + +@pytest.mark.usefixtures("mock_room_participant") +def test_signaling_websocket(mock_connection_manager): + # This will connect, send a text message, and then close + with client.websocket_connect( + "/api/v1/ws/signaling/room1?token=mock_token" + ) as websocket: + websocket.send_text(json.dumps({"type": "offer", "target_user_id": "user2"})) + # The connection manager's send_to_user should be called + + mock_connection_manager.connect.assert_called_once() + mock_connection_manager.send_to_user.assert_called_once_with( + "room1", "user2", {"type": "offer", "target_user_id": "user2"} + ) + mock_connection_manager.disconnect.assert_called_once_with("room1", "user1") + mock_connection_manager.broadcast_to_room.assert_called_once_with( + "room1", {"type": "peer_left", "user_id": "user1"}, sender_id="user1" + ) + + +@pytest.mark.usefixtures("mock_room_participant") +def test_audio_websocket_ingest( + mock_audio_ingest, + mock_kafka_consumer, +): + # We will simulate the async iterable for the consumer to yield nothing + mock_kafka_consumer.__aiter__.return_value = [] + + with client.websocket_connect( + "/api/v1/ws/audio/room1?token=mock_token" + ) as websocket: + websocket.send_bytes(b"fake_audio_chunk") + time.sleep(0.1) # Yield to event loop for background tasks to process + + mock_audio_ingest.reset_sequence.assert_called_once_with("room1:user1") diff --git a/tests/test_auth/test_auth_refresh.py b/tests/test_auth/test_auth_refresh.py new file mode 100644 index 0000000..a54c40a --- /dev/null +++ b/tests/test_auth/test_auth_refresh.py @@ -0,0 +1,415 @@ +"""Integration tests for ``POST /api/v1/auth/refresh-token``.""" + +from collections.abc import Generator +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.auth.account_lockout import ( + AccountLockoutService, + get_account_lockout_service, +) +from app.auth.models import User +from app.auth.token_store import TokenStoreService, get_token_store_service +from app.core.rate_limiter import limiter +from app.core.security import SecurityService +from app.db.session import get_db +from app.main import app +from app.models.base import Base +from app.services.email_producer import get_email_producer_service + +# --------------------------------------------------------------------------- +# Fake Redis — supports SCAN for revoke_all_user_tokens +# --------------------------------------------------------------------------- + + +class FakeRedis: + """In-memory stand-in for ``redis.asyncio.Redis`` with SCAN support.""" + + def __init__(self) -> None: + self._store: dict[str, str] = {} + + async def set( + self, + key: str, + value: str, + ex: int | None = None, # noqa: ARG002 + ) -> None: + self._store[key] = value + + async def get(self, key: str) -> str | None: + return self._store.get(key) + + async def delete(self, *keys: str) -> None: + for key in keys: + self._store.pop(key, None) + + async def exists(self, key: str) -> int: + return 1 if key in self._store else 0 + + async def incr(self, key: str) -> int: + current = int(self._store.get(key, "0")) + current += 1 + self._store[key] = str(current) + return current + + async def scan( + self, + cursor: int, # noqa: ARG002 + match: str | None = None, + count: int | None = None, # noqa: ARG002 + ) -> tuple[int, list[str]]: + """Return all keys matching *match* pattern in one shot (cursor=0).""" + import fnmatch + + if match: + # Convert Redis glob to fnmatch (Redis uses * for wildcard) + matched = [k for k in self._store if fnmatch.fnmatch(k, match)] + else: + matched = list(self._store.keys()) + # Return cursor=0 to signal iteration complete + return 0, matched + + def pipeline(self) -> "FakePipeline": + return FakePipeline(self) + + def reset(self) -> None: + self._store.clear() + + +class FakePipeline: + """Minimal pipeline stand-in that accumulates delete commands.""" + + def __init__(self, redis: FakeRedis) -> None: + self._redis = redis + self._cmds: list[str] = [] + + def delete(self, key: str) -> "FakePipeline": + self._cmds.append(key) + return self + + async def execute(self) -> None: + for key in self._cmds: + self._redis._store.pop(key, None) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def db_session() -> Generator[Session, None, None]: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + TestingSessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, + ) + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + +@pytest.fixture +def fake_redis() -> FakeRedis: + return FakeRedis() + + +@pytest.fixture +def email_producer_mock() -> AsyncMock: + mock = AsyncMock() + mock.send_email = AsyncMock() + return mock + + +@pytest.fixture +def token_store(fake_redis: FakeRedis) -> TokenStoreService: + return TokenStoreService(redis_client=fake_redis) # type: ignore[arg-type] + + +@pytest.fixture +def lockout_svc(fake_redis: FakeRedis) -> AccountLockoutService: + return AccountLockoutService(redis_client=fake_redis) # type: ignore[arg-type] + + +@pytest.fixture +def client( + db_session: Session, + email_producer_mock: AsyncMock, + token_store: TokenStoreService, + lockout_svc: AccountLockoutService, +) -> Generator[TestClient, None, None]: + def _override_get_db() -> Generator[Session, None, None]: + yield db_session + + def _override_email_producer() -> AsyncMock: + return email_producer_mock + + def _override_token_store() -> TokenStoreService: + return token_store + + def _override_lockout_svc() -> AccountLockoutService: + return lockout_svc + + app.dependency_overrides[get_db] = _override_get_db + app.dependency_overrides[get_email_producer_service] = _override_email_producer + app.dependency_overrides[get_token_store_service] = _override_token_store + app.dependency_overrides[get_account_lockout_service] = _override_lockout_svc + + limiter.enabled = False + with TestClient(app) as test_client: + yield test_client + limiter.enabled = True + app.dependency_overrides.clear() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_URL = "/api/v1/auth/refresh-token" +_SECURITY = SecurityService() + + +def _make_refresh_cookie(email: str) -> tuple[str, str, int]: + """Return (raw_token, jti, ttl) for seeding a valid refresh cookie.""" + token, jti, ttl = _SECURITY.create_refresh_token(email=email) + return token, jti, ttl + + +def _seed_user( + db: Session, + email: str = "refresh@example.com", + is_active: bool = True, + deleted_at: datetime | None = None, +) -> User: + user = User( + email=email.lower(), + hashed_password=_SECURITY.hash_password("Passw0rd!"), + full_name="Refresh User", + is_active=is_active, + is_verified=True, + deleted_at=deleted_at, + ) + db.add(user) + db.commit() + db.refresh(user) + return user + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + + +class TestRefreshTokenSuccess: + """Happy path: valid rotation returns new tokens and updates cookie.""" + + def test_returns_200_with_new_access_token( + self, + client: TestClient, + db_session: Session, + token_store: TokenStoreService, + ) -> None: + email = "refresh@example.com" + _seed_user(db_session, email=email) + raw_token, jti, ttl = _make_refresh_cookie(email) + + # Manually seed the JTI into fake Redis (as the login endpoint would) + import asyncio + + asyncio.run( + token_store.save_refresh_token(email=email, jti=jti, ttl_seconds=ttl) + ) + + response = client.post(_URL, cookies={"refresh_token": raw_token}) + + assert response.status_code == 200 + body = response.json() + assert "access_token" in body + assert body["token_type"] == "bearer" + assert body["expires_in"] > 0 + + def test_sets_new_httponly_refresh_cookie( + self, + client: TestClient, + db_session: Session, + token_store: TokenStoreService, + ) -> None: + email = "refresh@example.com" + _seed_user(db_session, email=email) + raw_token, jti, ttl = _make_refresh_cookie(email) + + import asyncio + + asyncio.run( + token_store.save_refresh_token(email=email, jti=jti, ttl_seconds=ttl) + ) + + response = client.post(_URL, cookies={"refresh_token": raw_token}) + + assert response.status_code == 200 + # TestClient exposes set-cookie as a header + set_cookie = response.headers.get("set-cookie", "") + assert "refresh_token=" in set_cookie + assert "HttpOnly" in set_cookie + assert "SameSite=strict" in set_cookie + + def test_old_jti_revoked_after_rotation( + self, + client: TestClient, + db_session: Session, + token_store: TokenStoreService, + ) -> None: + email = "refresh@example.com" + _seed_user(db_session, email=email) + raw_token, jti, ttl = _make_refresh_cookie(email) + + import asyncio + + asyncio.run( + token_store.save_refresh_token(email=email, jti=jti, ttl_seconds=ttl) + ) + + client.post(_URL, cookies={"refresh_token": raw_token}) + + # Old JTI must no longer exist in Redis + still_valid = asyncio.run( + token_store.is_refresh_token_valid(email=email, jti=jti) + ) + assert not still_valid + + +class TestRefreshTokenMissingCookie: + """No cookie provided.""" + + def test_returns_401_missing_refresh_token( + self, client: TestClient, db_session: Session + ) -> None: + _seed_user(db_session) + response = client.post(_URL) # no cookie + + assert response.status_code == 401 + assert response.json()["code"] == "MISSING_REFRESH_TOKEN" + + +class TestRefreshTokenInvalid: + """Tampered or expired tokens.""" + + def test_returns_401_for_garbage_token( + self, client: TestClient, db_session: Session + ) -> None: + _seed_user(db_session) + response = client.post(_URL, cookies={"refresh_token": "this.is.garbage"}) + + assert response.status_code == 401 + assert response.json()["code"] == "INVALID_REFRESH_TOKEN" + + def test_returns_401_for_access_token_used_as_refresh( + self, client: TestClient, db_session: Session + ) -> None: + """An access token must not be accepted as a refresh token.""" + _seed_user(db_session) + access_token, _ = _SECURITY.create_access_token(email="refresh@example.com") + + response = client.post(_URL, cookies={"refresh_token": access_token}) + + assert response.status_code == 401 + assert response.json()["code"] == "INVALID_REFRESH_TOKEN" + + +class TestRefreshTokenReuse: + """Replay attack: using a JTI that was already revoked.""" + + def test_returns_401_reuse_and_revokes_all_sessions( + self, + client: TestClient, + db_session: Session, + token_store: TokenStoreService, + ) -> None: + email = "refresh@example.com" + _seed_user(db_session, email=email) + raw_token, _jti, _ttl = _make_refresh_cookie(email) + + import asyncio + + # Seed a second "live" token for the same user to confirm it also + # gets wiped during the breach response. + _, jti2, ttl2 = _make_refresh_cookie(email) + asyncio.run( + token_store.save_refresh_token(email=email, jti=jti2, ttl_seconds=ttl2) + ) + + # Do NOT seed jti (simulate: old token already rotated/revoked) + # Attempt to use it — this is a reuse attack + response = client.post(_URL, cookies={"refresh_token": raw_token}) + + assert response.status_code == 401 + assert response.json()["code"] == "REFRESH_TOKEN_REUSE" + + # The second live token should also be wiped + still_valid = asyncio.run( + token_store.is_refresh_token_valid(email=email, jti=jti2) + ) + assert not still_valid + + +class TestRefreshTokenDeactivatedAccount: + """Account was deactivated after token was issued.""" + + def test_returns_403_for_deleted_account( + self, + client: TestClient, + db_session: Session, + token_store: TokenStoreService, + ) -> None: + email = "gone@example.com" + _seed_user(db_session, email=email, deleted_at=datetime.now(UTC)) + raw_token, jti, ttl = _make_refresh_cookie(email) + + import asyncio + + asyncio.run( + token_store.save_refresh_token(email=email, jti=jti, ttl_seconds=ttl) + ) + + response = client.post(_URL, cookies={"refresh_token": raw_token}) + + assert response.status_code == 403 + assert response.json()["code"] == "ACCOUNT_DEACTIVATED" + + def test_returns_403_for_inactive_account( + self, + client: TestClient, + db_session: Session, + token_store: TokenStoreService, + ) -> None: + email = "inactive@example.com" + _seed_user(db_session, email=email, is_active=False) + raw_token, jti, ttl = _make_refresh_cookie(email) + + import asyncio + + asyncio.run( + token_store.save_refresh_token(email=email, jti=jti, ttl_seconds=ttl) + ) + + response = client.post(_URL, cookies={"refresh_token": raw_token}) + + assert response.status_code == 403 + assert response.json()["code"] == "ACCOUNT_DEACTIVATED" diff --git a/tests/test_auth/test_auth_signup.py b/tests/test_auth/test_auth_signup.py index e12015a..90567fe 100644 --- a/tests/test_auth/test_auth_signup.py +++ b/tests/test_auth/test_auth_signup.py @@ -147,8 +147,8 @@ def test_forgot_password_returns_generic_accepted_response( assert response.status_code == 202 assert response.json() == { "message": ( - "If an account with that email exists, we have sent " - "password reset instructions." + "If an account with that email exists," + " we have sent password reset instructions." ) } email_producer_mock.send_email.assert_not_awaited() diff --git a/tests/test_auth/test_schemas_user.py b/tests/test_auth/test_schemas_user.py index e740af7..cab0c0f 100644 --- a/tests/test_auth/test_schemas_user.py +++ b/tests/test_auth/test_schemas_user.py @@ -15,6 +15,7 @@ def test_user_response_can_validate_from_attributes() -> None: listening_language="fr", is_active=True, is_verified=False, + user_role="user", created_at=datetime.now(UTC), ) diff --git a/tests/test_kafka/test_pipeline.py b/tests/test_kafka/test_pipeline.py new file mode 100644 index 0000000..f5f52fb --- /dev/null +++ b/tests/test_kafka/test_pipeline.py @@ -0,0 +1,192 @@ +import base64 +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.schemas.pipeline import ( + AudioChunkEvent, + AudioChunkPayload, + AudioEncoding, + TranscriptionEvent, + TranscriptionPayload, + TranslationEvent, + TranslationPayload, +) +from app.services.stt_worker import STTWorker +from app.services.translation_worker import TranslationWorker +from app.services.tts_worker import TTSWorker + + +@pytest.fixture +def mock_producer(): + producer = MagicMock() + producer.send = AsyncMock() + return producer + + +@pytest.fixture +def base_audio_chunk_event() -> AudioChunkEvent: + payload = AudioChunkPayload( + room_id="room123", + user_id="user456", + sequence_number=1, + audio_data=base64.b64encode(b"fake_audio").decode("ascii"), + sample_rate=16000, + encoding=AudioEncoding.LINEAR16, + source_language="en", + ) + return AudioChunkEvent(payload=payload) + + +@pytest.fixture +def base_transcription_event() -> TranscriptionEvent: + payload = TranscriptionPayload( + room_id="room123", + user_id="user456", + sequence_number=1, + text="Hello world", + source_language="en", + is_final=True, + confidence=0.95, + ) + return TranscriptionEvent(payload=payload) + + +@pytest.fixture +def base_translation_event() -> TranslationEvent: + payload = TranslationPayload( + room_id="room123", + user_id="user456", + sequence_number=1, + original_text="Hello world", + translated_text="Bonjour le monde", + source_language="en", + target_language="fr", + ) + return TranslationEvent(payload=payload) + + +@pytest.mark.asyncio +async def test_stt_worker_handle(mock_producer, base_audio_chunk_event): + worker = STTWorker(producer=mock_producer) + + with ( + patch("app.services.stt_worker.get_deepgram_stt_service") as mock_get_stt, + patch("app.core.config.settings") as mock_settings, + ): + mock_settings.DEEPGRAM_API_KEY = "fake-key" + + mock_stt_svc = AsyncMock() + mock_stt_svc.transcribe.return_value = { + "text": "Hello audio", + "confidence": 0.99, + "detected_language": "en", + } + mock_get_stt.return_value = mock_stt_svc + + await worker.handle(base_audio_chunk_event) + + mock_stt_svc.transcribe.assert_called_once_with( + b"fake_audio", + language="en", + sample_rate=16000, + encoding="linear16", + ) + mock_producer.send.assert_called_once() + args, kwargs = mock_producer.send.call_args + assert args[0] == "text.original" + assert isinstance(args[1], TranscriptionEvent) + assert args[1].payload.text == "Hello audio" + assert kwargs["key"] == "room123" + + +@pytest.mark.asyncio +async def test_translation_worker_handle(mock_producer, base_transcription_event): + worker = TranslationWorker(producer=mock_producer) + + with ( + patch( + "app.services.translation_worker.MeetingStateService" + ) as _mock_state_class, + patch( + "app.services.translation_worker.get_deepl_translation_service" + ) as mock_get_deepl, + patch("app.services.translation_worker.get_openai_translation_fallback"), + patch("app.core.config.settings") as mock_settings, + ): + mock_settings.DEEPL_API_KEY = "fake-deepl-key" + mock_settings.OPENAI_API_KEY = "fake-openai-key" + + mock_state = AsyncMock() + # Two users with different languages (fr and es) + mock_state.get_participants.return_value = { + "u1": {"language": "fr"}, + "u2": {"language": "es"}, + "u3": {"language": "en"}, # Same as source, should not translate + } + worker._state = mock_state + + mock_deepl = AsyncMock() + mock_deepl.supports_language.return_value = True + mock_deepl.translate.side_effect = ( + lambda _text, _source_language, target_language: { + "translated_text": f"Transl => {target_language}", + "latency_ms": 100, + } + ) + mock_get_deepl.return_value = mock_deepl + + await worker.handle(base_transcription_event) + + # Should translate twice (once for FR, once for ES) + assert mock_deepl.translate.call_count == 2 + assert mock_producer.send.call_count == 2 + + # Verify published events + calls = mock_producer.send.call_args_list + targets = set() + for call in calls: + args, kwargs = call + assert args[0] == "text.translated" + assert isinstance(args[1], TranslationEvent) + targets.add(args[1].payload.target_language) + assert kwargs["key"] == "room123" + + assert targets == {"fr", "es"} + + +@pytest.mark.asyncio +async def test_tts_worker_handle(mock_producer, base_translation_event): + worker = TTSWorker(producer=mock_producer) + + with ( + patch("app.services.tts_worker.get_openai_tts_service") as mock_get_openai, + patch("app.services.tts_worker.settings") as mock_settings, + ): + mock_settings.ACTIVE_TTS_PROVIDER = "openai" + mock_settings.PIPELINE_AUDIO_ENCODING = "linear16" + + mock_openai = AsyncMock() + mock_openai.synthesize.return_value = { + "audio_bytes": b"synthetic_audio_bytes", + "sample_rate": 24000, + } + mock_get_openai.return_value = mock_openai + + await worker.handle(base_translation_event) + + mock_openai.synthesize.assert_called_once_with( + "Bonjour le monde", + encoding="linear16", + ) + + mock_producer.send.assert_called_once() + args, _kwargs = mock_producer.send.call_args + assert args[0] == "audio.synthesized" + + synth_event = args[1] + assert synth_event.payload.sample_rate == 24000 + assert synth_event.payload.target_language == "fr" + + decoded = base64.b64decode(synth_event.payload.audio_data) + assert decoded == b"synthetic_audio_bytes"