Guard connect_to_deepgram against dg_connection.start() returning False (#6302).

connect_to_deepgram now returns None when start() is falsy, and
connect_to_deepgram_with_backoff treats a None result as a retryable failure,
returning None (rather than raising) once all attempts are used up.

Review notes:
  * Line structure and leading whitespace were reconstructed from a
    whitespace-collapsed copy of this patch. Hunk line counts were checked
    against the @@ headers; indentation of the context lines in
    test_streaming_deepgram_backoff.py and streaming.py is a best guess, so
    verify against the tree or apply with `git apply --ignore-whitespace`.
  * The test_desktop_transcribe.py hunk was dropped: it only added a
    "connect_to_deepgram: start() failure guard" section banner with no tests
    under it (the tests live in test_dg_start_guard.py).

diff --git a/backend/test.sh b/backend/test.sh
index 3f56ea4c34..13b3066a18 100755
--- a/backend/test.sh
+++ b/backend/test.sh
@@ -75,6 +75,7 @@ pytest tests/unit/test_sync_v2.py -v
 pytest tests/unit/test_sync_transcription_prefs.py -v
 pytest tests/unit/test_vision_stream_async.py -v
 pytest tests/unit/test_desktop_transcribe.py -v
+pytest tests/unit/test_dg_start_guard.py -v
 
 # Fair-use integration tests (require Redis; skip gracefully if unavailable)
 if redis-cli ping >/dev/null 2>&1; then
diff --git a/backend/tests/unit/test_dg_start_guard.py b/backend/tests/unit/test_dg_start_guard.py
new file mode 100644
index 0000000000..d993abef85
--- /dev/null
+++ b/backend/tests/unit/test_dg_start_guard.py
@@ -0,0 +1,91 @@
+"""Tests for connect_to_deepgram start() guard (#6302).
+
+Verifies that connect_to_deepgram returns None when dg_connection.start()
+returns False, preventing dead connections from being passed to callers.
+"""
+
+import os
+import sys
+from types import ModuleType
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+# ---------------------------------------------------------------------------
+# Minimal stubs — only what streaming.py actually needs at import time
+# ---------------------------------------------------------------------------
+
+# Stub database, heavy deps, and deepgram before importing.
+# deepgram stubs must match test_streaming_deepgram_backoff.py pattern to avoid
+# import-order pollution when pytest collects both files in the same process.
+for _mod_name in [
+    'database',
+    'database._client',
+    'database.redis_db',
+    'database.conversations',
+    'database.memories',
+    'database.users',
+    'firebase_admin',
+    'firebase_admin.auth',
+    'firebase_admin.messaging',
+    'models',
+    'models.other',
+    'models.transcript_segment',
+    'models.chat',
+    'models.conversation',
+    'models.notification_message',
+    'utils.log_sanitizer',
+    'deepgram',
+    'deepgram.clients',
+    'deepgram.clients.live',
+    'deepgram.clients.live.v1',
+    'websockets',
+    'websockets.exceptions',
+]:
+    sys.modules.setdefault(_mod_name, MagicMock())
+
+os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test')
+# NOTE: Do NOT set sys.modules['deepgram'].LiveTranscriptionEvents here.
+# MagicMock auto-generates attributes on access, and overwriting would pollute
+# shared pytest state for test_streaming_deepgram_backoff.py's close/error handler tests.
+
+# Now import the real streaming module
+from utils.stt.streaming import connect_to_deepgram
+
+
+class TestConnectToDeepgramStartGuard:
+    """Verify connect_to_deepgram returns None when start() returns False."""
+
+    @patch('utils.stt.streaming.deepgram')
+    def test_returns_none_when_start_fails(self, mock_dg):
+        """If dg_connection.start() returns False, must return None (#6302)."""
+        mock_dg_conn = MagicMock()
+        mock_dg_conn.start.return_value = False
+        mock_dg.listen.websocket.v.return_value = mock_dg_conn
+
+        result = connect_to_deepgram(
+            on_message=MagicMock(),
+            on_error=MagicMock(),
+            language='en',
+            sample_rate=16000,
+            channels=1,
+            model='nova-3',
+        )
+        assert result is None
+
+    @patch('utils.stt.streaming.deepgram')
+    def test_returns_connection_when_start_succeeds(self, mock_dg):
+        """If dg_connection.start() returns True, returns the connection."""
+        mock_dg_conn = MagicMock()
+        mock_dg_conn.start.return_value = True
+        mock_dg.listen.websocket.v.return_value = mock_dg_conn
+
+        result = connect_to_deepgram(
+            on_message=MagicMock(),
+            on_error=MagicMock(),
+            language='en',
+            sample_rate=16000,
+            channels=1,
+            model='nova-3',
+        )
+        assert result is mock_dg_conn
diff --git a/backend/tests/unit/test_streaming_deepgram_backoff.py b/backend/tests/unit/test_streaming_deepgram_backoff.py
index 405f9bd9d8..21b493e484 100644
--- a/backend/tests/unit/test_streaming_deepgram_backoff.py
+++ b/backend/tests/unit/test_streaming_deepgram_backoff.py
@@ -30,11 +30,14 @@
         _mock_modules[mod_name] = MagicMock()
         sys.modules[mod_name] = _mock_modules[mod_name]
 
-# Provide expected attributes for type-annotation imports
-sys.modules['deepgram'].DeepgramClient = MagicMock
-sys.modules['deepgram'].DeepgramClientOptions = MagicMock
-sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock()
-sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock
+# Provide expected attributes only if this file owns the deepgram mock.
+# When another test file (e.g. test_dg_start_guard.py) imported streaming.py first,
+# overwriting LiveTranscriptionEvents would break event-identity assertions (#6302).
+if 'deepgram' in _mock_modules:
+    sys.modules['deepgram'].DeepgramClient = MagicMock
+    sys.modules['deepgram'].DeepgramClientOptions = MagicMock
+    sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock()
+    sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock
 
 from utils.stt.streaming import connect_to_deepgram_with_backoff, process_audio_dg  # noqa: E402
 from utils.stt.streaming import deepgram_options, deepgram_cloud_options  # noqa: E402
@@ -289,6 +292,67 @@ async def test_process_audio_dg_no_vad_wrap_on_none():
     assert result is None
 
 
+@pytest.mark.asyncio
+async def test_retries_on_none_then_succeeds():
+    """When connect_to_deepgram returns None (start()==False), backoff retries and succeeds on later attempt."""
+    mock_conn = MagicMock()
+    call_count = 0
+
+    def none_then_succeed(*args, **kwargs):
+        nonlocal call_count
+        call_count += 1
+        if call_count < 3:
+            return None  # start() returned False
+        return mock_conn
+
+    sleep_calls = []
+
+    async def fake_sleep(duration):
+        sleep_calls.append(duration)
+
+    with patch('utils.stt.streaming.connect_to_deepgram', side_effect=none_then_succeed), patch(
+        'utils.stt.streaming.asyncio.sleep', side_effect=fake_sleep
+    ):
+        result = await connect_to_deepgram_with_backoff(
+            on_message=MagicMock(),
+            on_error=MagicMock(),
+            language='en',
+            sample_rate=16000,
+            channels=1,
+            model='nova-2-general',
+            retries=3,
+        )
+
+    assert result is mock_conn
+    assert call_count == 3
+    assert len(sleep_calls) == 2  # slept between attempt 1->2 and 2->3
+
+
+@pytest.mark.asyncio
+async def test_returns_none_after_all_none_retries_exhausted():
+    """When connect_to_deepgram returns None on all attempts, backoff returns None (not raise)."""
+    sleep_calls = []
+
+    async def fake_sleep(duration):
+        sleep_calls.append(duration)
+
+    with patch('utils.stt.streaming.connect_to_deepgram', return_value=None), patch(
+        'utils.stt.streaming.asyncio.sleep', side_effect=fake_sleep
+    ):
+        result = await connect_to_deepgram_with_backoff(
+            on_message=MagicMock(),
+            on_error=MagicMock(),
+            language='en',
+            sample_rate=16000,
+            channels=1,
+            model='nova-2-general',
+            retries=3,
+        )
+
+    assert result is None
+    assert len(sleep_calls) == 2  # slept between retries
+
+
 def test_deepgram_options_no_keepalive():
     """SDK keepalive option must not be present — it spawns a dangerous background thread (#5870)."""
     for name, opts in [('deepgram_options', deepgram_options), ('deepgram_cloud_options', deepgram_cloud_options)]:
diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py
index 756dd79946..387b33e7e9 100644
--- a/backend/utils/stt/streaming.py
+++ b/backend/utils/stt/streaming.py
@@ -285,9 +285,16 @@ async def connect_to_deepgram_with_backoff(
             logger.warning("Session ended, aborting Deepgram retry")
             return None
         try:
-            return await asyncio.to_thread(
+            result = await asyncio.to_thread(
                 connect_to_deepgram, on_message, on_error, language, sample_rate, channels, model, keywords
             )
+            if result is not None:
+                return result
+            # start() returned False — retry unless this is the last attempt
+            if attempt == retries - 1:
+                logger.error('Deepgram start() returned False on all %d attempts — giving up', retries)
+                return None
+            logger.warning('Deepgram start() returned False (attempt %d/%d), retrying...', attempt + 1, retries)
         except Exception as error:
             logger.error(f'An error occurred: {error}')
             if attempt == retries - 1:  # Last attempt
@@ -361,6 +368,9 @@ def on_unhandled(self, unhandled, **kwargs):
 
         result = dg_connection.start(options)
         logger.info(f'Deepgram connection started: {result}')
+        if not result:
+            logger.error('Deepgram connection start() returned False — connection not established')
+            return None
         return dg_connection
     except websockets.exceptions.WebSocketException as e:
         raise Exception(f'Could not open socket: WebSocketException {e}')