diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index c2236ea..d720aed 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.63.1" +__version__ = "0.64.0" diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py index f7be7d3..e89ad55 100644 --- a/assemblyai/streaming/v3/__init__.py +++ b/assemblyai/streaming/v3/__init__.py @@ -11,6 +11,8 @@ StreamingError, StreamingEvents, StreamingParameters, + StreamingPiiPolicy, + StreamingPiiSubstitution, StreamingSessionParameters, TerminationEvent, TurnEvent, @@ -31,6 +33,8 @@ "StreamingError", "StreamingEvents", "StreamingParameters", + "StreamingPiiPolicy", + "StreamingPiiSubstitution", "StreamingSessionParameters", "TerminationEvent", "TurnEvent", diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py index aaeabee..a9d2581 100644 --- a/assemblyai/streaming/v3/client.py +++ b/assemblyai/streaming/v3/client.py @@ -43,6 +43,12 @@ def _dump_model(model: BaseModel): return model.dict(exclude_none=True) +def _parse_model(model_class, data): + if hasattr(model_class, "model_validate"): + return model_class.model_validate(data) + return model_class.parse_obj(data) + + def _normalize_min_turn_silence(params_dict: dict) -> dict: """Collapse `min_end_of_turn_silence_when_confident` into `min_turn_silence` so only one wire key is ever sent. Emits deprecation warnings.""" @@ -65,6 +71,31 @@ def _normalize_min_turn_silence(params_dict: dict) -> dict: return params_dict +def _normalize_voice_focus(params_dict: dict) -> dict: + """Collapse `noise_suppression_model` / `noise_suppression_threshold` into + `voice_focus` / `voice_focus_threshold` so only the new wire keys are sent. 
+ Emits deprecation warnings.""" + for old_key, new_key in ( + ("noise_suppression_model", "voice_focus"), + ("noise_suppression_threshold", "voice_focus_threshold"), + ): + old = params_dict.pop(old_key, None) + if old is None: + continue + if new_key in params_dict: + logger.warning( + f"[Deprecation Warning] Both `{old_key}` and `{new_key}` are set. " + f"Using `{new_key}`; `{old_key}` is deprecated." + ) + else: + logger.warning( + f"[Deprecation Warning] `{old_key}` is deprecated and will be removed " + f"in a future release. Please use `{new_key}` instead." + ) + params_dict[new_key] = old + return params_dict + + def _dump_model_json(model: BaseModel): if hasattr(model, "model_dump_json"): return model.model_dump_json(exclude_none=True) @@ -94,6 +125,17 @@ def __init__(self, options: StreamingClientOptions): self._write_thread = threading.Thread(target=self._write_message) self._read_thread = threading.Thread(target=self._read_message) self._stop_event = threading.Event() + # Both flags are read and set only on the read thread (or on the main + # thread before workers start, for handshake errors). Plain bools are + # sufficient — no cross-thread synchronization is needed. + self._connection_closed_reported = False + self._server_error_reported = False + # Deliberate single-slot shared-memory handoff: the write thread parks + # a ConnectionClosed here and the read thread drains it. Synchronization + # is provided by `_stop_event.set()` (write side) + `recv(timeout=1)` + # (read side), which together give a happens-before within ~1s. + self._pending_close_error: Optional[Exception] = None + self._websocket = None def connect(self, params: StreamingParameters) -> None: if params.speech_model == "u3-pro": @@ -102,7 +144,15 @@ def connect(self, params: StreamingParameters) -> None: "Please use `u3-rt-pro` instead." 
) - params_dict = _normalize_min_turn_silence(_dump_model(params)) + if params.customer_support_audio_capture: + logger.warning( + "`customer_support_audio_capture=True` will record session audio. " + "Only enable this when explicitly coordinating with AssemblyAI support." + ) + + params_dict = _normalize_voice_focus( + _normalize_min_turn_silence(_dump_model(params)) + ) # JSON-encode list and dict parameters for proper API compatibility (e.g., keyterms_prompt, llm_gateway) for key, value in params_dict.items(): @@ -132,8 +182,22 @@ def connect(self, params: StreamingParameters) -> None: additional_headers=headers, open_timeout=15, ) - except websockets.exceptions.ConnectionClosed as exc: - self._handle_error(exc) + except websockets.exceptions.InvalidStatus as exc: + status_code = getattr(getattr(exc, "response", None), "status_code", None) + self._report_connection_closed( + StreamingError( + message=f"WebSocket handshake rejected (HTTP {status_code})", + code=status_code, + ) + ) + return + except ( + websockets.exceptions.InvalidHandshake, + websockets.exceptions.ConnectionClosed, + OSError, + TimeoutError, + ) as exc: + self._report_connection_closed(exc) return self._write_thread.start() @@ -145,23 +209,40 @@ def disconnect(self, terminate: bool = False) -> None: if terminate and not self._stop_event.is_set(): self._write_queue.put(TerminateSession()) - try: - self._read_thread.join() - self._write_thread.join() + self._stop_event.set() + + current = threading.current_thread() + for thread in (self._read_thread, self._write_thread): + if thread is current or not thread.is_alive(): + continue + try: + thread.join() + except RuntimeError as exc: + logger.debug("Thread join skipped: %s", exc) - if self._websocket: - self._websocket.close() - except Exception: - pass + self._close_websocket() + + def _close_websocket(self) -> None: + if not self._websocket: + return + try: + self._websocket.close() + except (OSError, websockets.exceptions.WebSocketException) as 
exc: + logger.debug("Error closing websocket: %s", exc) def stream( self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]] ) -> None: + if self._stop_event.is_set(): + return + if isinstance(data, bytes): self._write_queue.put(data) return for chunk in data: + if self._stop_event.is_set(): + return self._write_queue.put(chunk) def set_params(self, params: StreamingSessionParameters): @@ -178,15 +259,24 @@ def on(self, event: StreamingEvents, handler: Callable) -> None: self._handlers[event].append(handler) def _write_message(self) -> None: - while not self._stop_event.is_set(): + while True: if not self._websocket: raise ValueError("Not connected to the WebSocket server") try: data = self._write_queue.get(timeout=1) except queue.Empty: + if self._stop_event.is_set(): + return continue + # TerminateSession bypasses the stop gate so disconnect(terminate=True) + # can always send it, even when stop is set between put() and the + # write loop's next iteration. + is_terminate = isinstance(data, TerminateSession) + if not is_terminate and self._stop_event.is_set(): + return + try: if isinstance(data, bytes): self._websocket.send(data) @@ -195,20 +285,36 @@ def _write_message(self) -> None: else: raise ValueError(f"Attempted to send invalid message: {type(data)}") except websockets.exceptions.ConnectionClosed as exc: - self._handle_error(exc) + # Defer reporting to the read thread so all on_error dispatch + # happens on a single thread (no cross-thread dedup race). + self._pending_close_error = exc + self._stop_event.set() + return + + if is_terminate: return def _read_message(self) -> None: - while not self._stop_event.is_set(): + while True: if not self._websocket: raise ValueError("Not connected to the WebSocket server") + # Drain a write-thread close before honoring stop, so a stop set by + # the write thread doesn't cause us to exit silently with an + # unreported close. 
+ if self._pending_close_error is not None: + pending, self._pending_close_error = self._pending_close_error, None + self._report_connection_closed(pending) + return + if self._stop_event.is_set(): + return + try: message_data = self._websocket.recv(timeout=1) except TimeoutError: continue except websockets.exceptions.ConnectionClosed as exc: - self._handle_error(exc) + self._report_connection_closed(exc) return try: @@ -220,7 +326,7 @@ def _read_message(self) -> None: message = self._parse_message(message_json) if isinstance(message, ErrorEvent): - self._handle_error(message) + self._report_server_error(message) elif isinstance(message, WarningEvent): self._handle_warning(message) elif message: @@ -244,23 +350,23 @@ def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]: event_type = self._parse_event_type(message_type) if event_type == StreamingEvents.Begin: - return BeginEvent.model_validate(data) + return _parse_model(BeginEvent, data) elif event_type == StreamingEvents.Termination: - return TerminationEvent.model_validate(data) + return _parse_model(TerminationEvent, data) elif event_type == StreamingEvents.Turn: - return TurnEvent.model_validate(data) + return _parse_model(TurnEvent, data) elif event_type == StreamingEvents.SpeechStarted: - return SpeechStartedEvent.model_validate(data) + return _parse_model(SpeechStartedEvent, data) elif event_type == StreamingEvents.LLMGatewayResponse: - return LLMGatewayResponseEvent.model_validate(data) + return _parse_model(LLMGatewayResponseEvent, data) elif event_type == StreamingEvents.Error: - return ErrorEvent.model_validate(data) + return _parse_model(ErrorEvent, data) elif event_type == StreamingEvents.Warning: - return WarningEvent.model_validate(data) + return _parse_model(WarningEvent, data) else: return None elif "error" in data: - return ErrorEvent.model_validate(data) + return _parse_model(ErrorEvent, data) return None @@ -281,44 +387,85 @@ def _handle_warning(self, warning: WarningEvent): 
for handler in self._handlers[StreamingEvents.Warning]: handler(self, warning) - def _handle_error( + def _report_server_error(self, error: ErrorEvent) -> None: + self._server_error_reported = True + streaming_error = StreamingError( + message=error.error, + code=error.error_code, + ) + logger.error("Streaming error: %s (code=%s)", error.error, error.error_code) + self._dispatch_error(streaming_error) + + def _report_connection_closed( self, error: Union[ + StreamingError, ErrorEvent, websockets.exceptions.ConnectionClosed, + OSError, ], - ): - parsed_error = self._parse_error(error) + ) -> None: + # Idempotent: defensive guard in case future callers (e.g. another + # connect-time error path) reach this method twice. + if self._connection_closed_reported: + return + self._connection_closed_reported = True + self._stop_event.set() - for handler in self._handlers[StreamingEvents.Error]: - handler(self, parsed_error) + streaming_error = self._build_connection_closed_error(error) - self.disconnect() + # Clean close (code 1000) → no streaming_error, nothing to report. + if streaming_error is None: + self._close_websocket() + return - def _parse_error( - self, + if isinstance(error, websockets.exceptions.ConnectionClosed): + reason = error.reason or "no reason given" + logger.error("Connection closed: %s (code=%s)", reason, error.code) + else: + logger.error( + "Connection failed: %s (code=%s)", + streaming_error, + streaming_error.code, + ) + + # If a server Error frame already fired on_error, the close is the + # effect, not a new cause — log it (above) but skip the duplicate + # user-visible error. 
+ if not self._server_error_reported: + self._dispatch_error(streaming_error) + + self._close_websocket() + + def _dispatch_error(self, error: StreamingError) -> None: + for handler in self._handlers[StreamingEvents.Error]: + try: + handler(self, error) + except Exception: + logger.exception("on_error handler raised") + + @staticmethod + def _build_connection_closed_error( error: Union[ + StreamingError, ErrorEvent, websockets.exceptions.ConnectionClosed, + OSError, ], - ) -> StreamingError: + ) -> Optional[StreamingError]: + if isinstance(error, StreamingError): + return error if isinstance(error, ErrorEvent): - return StreamingError( - message=error.error, - code=error.error_code, - ) - elif isinstance(error, websockets.exceptions.ConnectionClosed): - if error.code in StreamingErrorCodes: - error_message = StreamingErrorCodes[error.code] + return StreamingError(message=error.error, code=error.error_code) + if isinstance(error, websockets.exceptions.ConnectionClosed): + if error.code == 1000: + return None + if error.code is not None and error.code in StreamingErrorCodes: + message = StreamingErrorCodes[error.code] else: - error_message = error.reason - - if error.code != 1000: - return StreamingError(message=error_message, code=error.code) - - return StreamingError( - message=f"Unknown error: {error}", - ) + message = error.reason or f"Connection closed (code={error.code})" + return StreamingError(message=message, code=error.code) + return StreamingError(message=f"Connection failed: {error}") def create_temporary_token( self, diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index 40b4def..4115f9f 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -22,6 +22,7 @@ class Word(BaseModel): confidence: float text: str word_is_final: bool + speaker: Optional[str] = None class TurnEvent(BaseModel): @@ -140,6 +141,67 @@ def __str__(self): return self.value +class StreamingPiiSubstitution(str, Enum): + 
hash = "hash" + entity_name = "entity_name" + + def __str__(self): + return self.value + + +class StreamingPiiPolicy(str, Enum): + account_number = "account_number" + banking_information = "banking_information" + blood_type = "blood_type" + credit_card_number = "credit_card_number" + credit_card_expiration = "credit_card_expiration" + credit_card_cvv = "credit_card_cvv" + date = "date" + date_interval = "date_interval" + date_of_birth = "date_of_birth" + drivers_license = "drivers_license" + drug = "drug" + duration = "duration" + email_address = "email_address" + event = "event" + filename = "filename" + gender_sexuality = "gender_sexuality" + gender = "gender" + healthcare_number = "healthcare_number" + injury = "injury" + ip_address = "ip_address" + language = "language" + location = "location" + marital_status = "marital_status" + medical_condition = "medical_condition" + medical_process = "medical_process" + money_amount = "money_amount" + nationality = "nationality" + number_sequence = "number_sequence" + passport_number = "passport_number" + password = "password" + person_age = "person_age" + person_name = "person_name" + phone_number = "phone_number" + physical_attribute = "physical_attribute" + political_affiliation = "political_affiliation" + occupation = "occupation" + organization = "organization" + organization_medical_facility = "organization_medical_facility" + religion = "religion" + sexuality = "sexuality" + statistics = "statistics" + time = "time" + url = "url" + us_social_security_number = "us_social_security_number" + username = "username" + vehicle_id = "vehicle_id" + zodiac_sign = "zodiac_sign" + + def __str__(self): + return self.value + + class StreamingParameters(StreamingSessionParameters): sample_rate: int encoding: Optional[Encoding] = None @@ -153,8 +215,17 @@ class StreamingParameters(StreamingSessionParameters): llm_gateway: Optional[LLMGatewayConfig] = None speaker_labels: Optional[bool] = None max_speakers: Optional[int] = None + 
voice_focus: Optional[NoiseSuppressionModel] = None + voice_focus_threshold: Optional[float] = None + # Deprecated: use voice_focus / voice_focus_threshold instead. noise_suppression_model: Optional[NoiseSuppressionModel] = None noise_suppression_threshold: Optional[float] = None + continuous_partials: Optional[bool] = None + customer_support_audio_capture: Optional[bool] = None + include_partial_turns: Optional[bool] = None + redact_pii: Optional[bool] = None + redact_pii_policies: Optional[List[StreamingPiiPolicy]] = None + redact_pii_sub: Optional[StreamingPiiSubstitution] = None class UpdateConfiguration(StreamingSessionParameters): diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index c1b54d7..8151899 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -1,8 +1,15 @@ +import json +import logging +import threading +import time +from types import SimpleNamespace from urllib.parse import urlencode import pytest from pydantic import ValidationError from pytest_mock import MockFixture +from websockets.exceptions import ConnectionClosed, InvalidStatus +from websockets.frames import Close from assemblyai.streaming.v3 import ( NoiseSuppressionModel, @@ -10,9 +17,14 @@ SpeechStartedEvent, StreamingClient, StreamingClientOptions, + StreamingEvents, StreamingParameters, + StreamingPiiPolicy, + StreamingPiiSubstitution, TurnEvent, + Word, ) +from assemblyai.streaming.v3.models import TerminateSession def _disable_rw_threads(mocker: MockFixture): @@ -157,7 +169,7 @@ def mocked_websocket_connect( assert actual_open_timeout == 15 -def test_client_connect_with_noise_suppression(mocker: MockFixture): +def test_client_connect_with_redact_pii(mocker: MockFixture): actual_url = None def mocked_websocket_connect( @@ -179,19 +191,153 @@ def mocked_websocket_connect( params = StreamingParameters( sample_rate=16000, speech_model=SpeechModel.universal_streaming_english, - noise_suppression_model=NoiseSuppressionModel.near_field, - 
noise_suppression_threshold=0.5, + include_partial_turns=False, + redact_pii=True, + redact_pii_policies=[ + StreamingPiiPolicy.email_address, + StreamingPiiPolicy.phone_number, + ], + redact_pii_sub=StreamingPiiSubstitution.entity_name, ) client.connect(params) - expected_headers = { - "sample_rate": params.sample_rate, - "speech_model": str(params.speech_model), - "noise_suppression_model": str(params.noise_suppression_model), - "noise_suppression_threshold": params.noise_suppression_threshold, - } + assert "include_partial_turns=False" in actual_url + assert "redact_pii=True" in actual_url + assert "redact_pii_sub=entity_name" in actual_url + assert "redact_pii_policies=" in actual_url + assert "email_address" in actual_url + assert "phone_number" in actual_url - assert actual_url == f"wss://api.example.com/v3/ws?{urlencode(expected_headers)}" + +def test_client_connect_with_voice_focus(mocker: MockFixture): + # Given: client + voice_focus parameters + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + voice_focus=NoiseSuppressionModel.near_field, + voice_focus_threshold=0.5, + ) + + # When: connect + client.connect(params) + + # Then: new wire keys are sent; old keys never appear + assert "voice_focus=near-field" in actual_url + assert "voice_focus_threshold=0.5" in actual_url + assert "noise_suppression_model" not in actual_url + assert "noise_suppression_threshold" not in actual_url + + +def test_noise_suppression_deprecated_alias_migrates_to_voice_focus( + mocker: MockFixture, caplog: pytest.LogCaptureFixture 
+): + # Given: a client passing the legacy noise_suppression_* fields + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + noise_suppression_model=NoiseSuppressionModel.far_field, + noise_suppression_threshold=0.7, + ) + + # When: connect + with caplog.at_level(logging.WARNING): + client.connect(params) + + # Then: legacy values migrate to voice_focus_* on the wire and a deprecation + # warning is logged for each migrated field + assert "voice_focus=far-field" in actual_url + assert "voice_focus_threshold=0.7" in actual_url + assert "noise_suppression_model" not in actual_url + assert "noise_suppression_threshold" not in actual_url + assert any( + "noise_suppression_model" in r.message and "deprecated" in r.message + for r in caplog.records + ) + assert any( + "noise_suppression_threshold" in r.message and "deprecated" in r.message + for r in caplog.records + ) + + +def test_voice_focus_conflict_prefers_new_name( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + # Given: both legacy and new fields are set + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + 
voice_focus=NoiseSuppressionModel.near_field, + voice_focus_threshold=0.4, + noise_suppression_model=NoiseSuppressionModel.far_field, + noise_suppression_threshold=0.9, + ) + + # When: connect + with caplog.at_level(logging.WARNING): + client.connect(params) + + # Then: voice_focus wins; conflict warning logged for each field + assert "voice_focus=near-field" in actual_url + assert "voice_focus_threshold=0.4" in actual_url + assert "noise_suppression_model" not in actual_url + assert "noise_suppression_threshold" not in actual_url + assert any( + "Both `noise_suppression_model` and `voice_focus` are set" in r.message + for r in caplog.records + ) + assert any( + "Both `noise_suppression_threshold` and `voice_focus_threshold` are set" + in r.message + for r in caplog.records + ) def test_api_host_accepts_ws_scheme(mocker: MockFixture): @@ -473,6 +619,106 @@ def mocked_websocket_connect( assert "max_speakers=3" in actual_url +def test_client_connect_with_continuous_partials(mocker: MockFixture): + # Given: client + continuous_partials=True (U3-Pro steady-partials mode) + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.u3_rt_pro, + continuous_partials=True, + ) + + # When: connect + client.connect(params) + + # Then: parameter reaches the URL + assert "continuous_partials=True" in actual_url + + +def test_customer_support_audio_capture_warns_when_enabled( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + # Given: client + customer_support_audio_capture=True + actual_url = None + + def mocked_websocket_connect( + url: str, 
additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + customer_support_audio_capture=True, + ) + + # When: connect + with caplog.at_level(logging.WARNING): + client.connect(params) + + # Then: parameter reaches the URL and a warning is logged + assert "customer_support_audio_capture=True" in actual_url + assert any( + "session audio" in r.message and "support" in r.message for r in caplog.records + ) + + +def test_customer_support_audio_capture_no_warning_when_disabled( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + # Given: client without customer_support_audio_capture + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + pass + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + ) + + # When: connect + with caplog.at_level(logging.WARNING): + client.connect(params) + + # Then: no support-audio warning is logged + assert not any( + "session audio" in r.message and "support" in r.message for r in caplog.records + ) + + def test_client_connect_with_whisper_rt(mocker: MockFixture): actual_url = None @@ -531,6 +777,78 @@ def test_turn_event_without_speaker_label(): assert event.speaker_label is None +def test_word_with_speaker_field(): + # Given: a Word payload that includes a per-word speaker label + data = { + "start": 100, + "end": 250, 
+ "confidence": 0.92, + "text": "hello", + "word_is_final": True, + "speaker": "A", + } + + # When: parsed + word = Word.parse_obj(data) + + # Then: the speaker label is preserved + assert word.speaker == "A" + + +def test_word_without_speaker_field_defaults_to_none(): + # Given: a Word payload that omits the speaker label + data = { + "start": 100, + "end": 250, + "confidence": 0.92, + "text": "hello", + "word_is_final": True, + } + + # When: parsed + word = Word.parse_obj(data) + + # Then: speaker is optional → None + assert word.speaker is None + + +def test_turn_event_with_word_speakers(): + # Given: a TurnEvent with two words carrying distinct per-word speaker labels + data = { + "type": "Turn", + "turn_order": 1, + "turn_is_formatted": True, + "end_of_turn": True, + "transcript": "Hello world", + "end_of_turn_confidence": 0.85, + "words": [ + { + "start": 0, + "end": 100, + "confidence": 0.9, + "text": "Hello", + "word_is_final": True, + "speaker": "A", + }, + { + "start": 110, + "end": 200, + "confidence": 0.9, + "text": "world", + "word_is_final": True, + "speaker": "B", + }, + ], + "speaker_label": "A", + } + + # When: parsed + event = TurnEvent.parse_obj(data) + + # Then: each word's speaker is preserved + assert [w.speaker for w in event.words] == ["A", "B"] + + def test_speech_model_required(): """Test that omitting speech_model raises a validation error.""" with pytest.raises(ValidationError): @@ -546,3 +864,343 @@ def test_speech_started_event(): event = SpeechStartedEvent.parse_obj(data) assert event.type == "SpeechStarted" assert event.timestamp == 1280 + + +class _FakeWebSocket: + """Programmable sync websocket stand-in for driving StreamingClient in tests.""" + + def __init__(self, recv_script, send_raises=None, send_blocks_until=None): + self._recv_script = list(recv_script) + self._send_raises = send_raises + # Optional Event the test can use to hold send() until a barrier point + # (e.g. 
"release send only after the read thread has reached + # _report_server_error"), making the read+write race deterministic. + self._send_blocks_until = send_blocks_until + self._recv_lock = threading.Lock() + self.close_call_count = 0 + self.send_call_count = 0 + self.sent = [] + + def recv(self, timeout=None): + with self._recv_lock: + if not self._recv_script: + raise TimeoutError() + item = self._recv_script.pop(0) + if isinstance(item, BaseException): + raise item + return item + + def send(self, data): + self.send_call_count += 1 + if self._send_blocks_until is not None: + self._send_blocks_until.wait(timeout=2.0) + if self._send_raises is not None: + raise self._send_raises + self.sent.append(data) + + def close(self): + self.close_call_count += 1 + + +def _connect_and_wait(client, params, seed_chunks=None, timeout=2.0): + # Prime the write queue BEFORE connect so the write thread's first get() + # returns immediately. If we primed after connect, client.stream() can + # short-circuit on _stop_event (set by a fast-firing close path) and the + # write thread parks in get(timeout=1) for a full second, never reaching + # send() — which means the read+write race the tests target never happens. 
+    # Pre-load the write queue before connecting, so the chunks are already
+    # queued by the time connect() runs.
+    if seed_chunks is not None:
+        for chunk in seed_chunks:
+            client._write_queue.put(chunk)
+    client.connect(params)
+    # Poll every 20 ms until both worker threads have finished and the stop
+    # event is set. A thread whose `ident` is None was never started, so it
+    # counts as done. If the deadline passes first, the loop just ends without
+    # raising and the caller's assertions run against whatever state exists.
+    deadline = time.monotonic() + timeout
+    while time.monotonic() < deadline:
+        read_done = (
+            not client._read_thread.is_alive()
+            if client._read_thread.ident is not None
+            else True
+        )
+        write_done = (
+            not client._write_thread.is_alive()
+            if client._write_thread.ident is not None
+            else True
+        )
+        if read_done and write_done and client._stop_event.is_set():
+            return
+        time.sleep(0.02)
+
+
+def _default_params():
+    """Return the StreamingParameters shared by the tests below.
+
+    Uses sample_rate=16000 and the universal-streaming English speech model.
+    """
+    return StreamingParameters(
+        sample_rate=16000,
+        speech_model=SpeechModel.universal_streaming_english,
+    )
+
+
+def test_error_event_then_close_fires_only_once(
+    mocker: MockFixture, caplog: pytest.LogCaptureFixture
+):
+    """A server Error followed by a 4001 close reaches the handler exactly once.
+
+    Also checks that exactly one "Streaming error" and exactly one
+    "Connection closed" record are logged, both at ERROR level.
+    """
+    # Given: server Error then close + a barrier that holds send() until the
+    # read thread enters _report_server_error, guaranteeing a real read+write
+    # race on the close (not just an artifact of recv_script ordering).
+    caplog.set_level(logging.ERROR)
+    error_json = json.dumps(
+        {"type": "Error", "error": "Invalid API key", "error_code": 4001}
+    )
+    close_exc = ConnectionClosed(rcvd=Close(4001, "Not Authorized"), sent=None)
+    send_gate = threading.Event()
+    # NOTE(review): _FakeWebSocket is defined outside this section; presumably
+    # `send_blocks_until` makes send() wait on the event before raising
+    # `send_raises` — confirm against the fake's definition.
+    fake_ws = _FakeWebSocket(
+        recv_script=[error_json, close_exc],
+        send_raises=close_exc,
+        send_blocks_until=send_gate,
+    )
+    mocker.patch(
+        "assemblyai.streaming.v3.client.websocket_connect",
+        return_value=fake_ws,
+    )
+    # Keep a reference to the unpatched method so the wrapper can delegate.
+    real_report_server_error = StreamingClient._report_server_error
+
+    def report_server_error_then_release(self, error):
+        # Open the gate first, then run the real reporting logic.
+        send_gate.set()
+        return real_report_server_error(self, error)
+
+    mocker.patch.object(
+        StreamingClient, "_report_server_error", report_server_error_then_release
+    )
+    received = []
+    received_lock = threading.Lock()
+
+    def on_error(self_, err):
+        # Appends are serialized with a lock because this test deliberately
+        # races two threads on the close path.
+        with received_lock:
+            received.append(err)
+
+    client = StreamingClient(
+        StreamingClientOptions(api_key="test", api_host="api.example.com")
+    )
+    client.on(StreamingEvents.Error, on_error)
+
+    # When: connect with a primed write queue so the write thread reaches send()
+    # and parks on the barrier; the read thread races when it releases the gate.
+    _connect_and_wait(
+        client,
+        _default_params(),
+        seed_chunks=[b"\x00" * 320] * 50,
+    )
+
+    # Then: exactly one on_error with the rich server-error content.
+    assert len(received) == 1, (
+        f"expected exactly 1 error, got {len(received)}: {received}"
+    )
+    assert str(received[0]) == "Invalid API key"
+    assert received[0].code == 4001
+    assert fake_ws.close_call_count >= 1
+    assert fake_ws.send_call_count >= 1, "write thread never reached send()"
+    assert not client._read_thread.is_alive()
+    assert not client._write_thread.is_alive()
+    # Both log records must mention 4001 and appear exactly once each.
+    error_logs = [
+        rec
+        for rec in caplog.records
+        if "Streaming error" in rec.message and "4001" in rec.message
+    ]
+    close_logs = [
+        rec
+        for rec in caplog.records
+        if "Connection closed" in rec.message and "4001" in rec.message
+    ]
+    assert len(error_logs) == 1, (
+        f"expected exactly 1 Streaming-error log, got {len(error_logs)}"
+    )
+    assert error_logs[0].levelno == logging.ERROR
+    assert len(close_logs) == 1, (
+        f"expected exactly 1 Connection-closed log, got {len(close_logs)}"
+    )
+    assert close_logs[0].levelno == logging.ERROR
+
+    client.disconnect(terminate=True)
+
+
+def test_handler_exception_does_not_block_shutdown(mocker: MockFixture):
+    """An Error handler that raises must not prevent cleanup.
+
+    The websocket must still be closed and both worker threads must exit.
+    """
+    # Given: a websocket that raises ConnectionClosed and an on_error handler that throws
+    close_exc = ConnectionClosed(rcvd=Close(1011, "server error"), sent=None)
+    fake_ws = _FakeWebSocket(recv_script=[close_exc], send_raises=close_exc)
+    mocker.patch(
+        "assemblyai.streaming.v3.client.websocket_connect",
+        return_value=fake_ws,
+    )
+
+    def bad_handler(self_, err):
+        raise RuntimeError("boom")
+
+    client = StreamingClient(
+        StreamingClientOptions(api_key="test", api_host="api.example.com")
+    )
+    client.on(StreamingEvents.Error, bad_handler)
+
+    # When: the client connects and the handler raises during error dispatch
+    _connect_and_wait(client, _default_params())
+
+    # Then: cleanup still completes — websocket closed, both worker threads exited
+    assert fake_ws.close_call_count >= 1
+    assert not client._read_thread.is_alive()
+    assert not client._write_thread.is_alive()
+
+    client.disconnect(terminate=True)
+
+
+def test_invalid_status_401_during_connect(mocker: MockFixture):
+    """A handshake rejected with HTTP 401 dispatches a single error.
+
+    The error must carry code 401, and neither worker thread may be started.
+    """
+    # Given: websocket_connect raises InvalidStatus carrying an HTTP 401 response
+    # Minimal stand-in for the HTTP response; only `status_code` is provided.
+    response = SimpleNamespace(status_code=401)
+    invalid_status = InvalidStatus(response=response)
+    mocker.patch(
+        "assemblyai.streaming.v3.client.websocket_connect",
+        side_effect=invalid_status,
+    )
+    # Spy on Thread.start so every start() call made during connect() can be
+    # inspected afterwards.
+    start_spy = mocker.spy(threading.Thread, "start")
+    received = []
+
+    def on_error(self_, err):
+        received.append(err)
+
+    client = StreamingClient(
+        StreamingClientOptions(api_key="test", api_host="api.example.com")
+    )
+    client.on(StreamingEvents.Error, on_error)
+
+    # When: connect() is called and the handshake is rejected
+    client.connect(_default_params())
+
+    # Then: a single error with code=401 is dispatched, and neither worker thread is started
+    assert len(received) == 1
+    assert received[0].code == 401
+    assert not client._read_thread.is_alive()
+    assert not client._write_thread.is_alive()
+    # NOTE(review): presumably call.args[0] is the Thread instance start() was
+    # invoked on — confirm against pytest-mock's spy behaviour.
+    for call in start_spy.call_args_list:
+        assert call.args[0] not in (client._read_thread, client._write_thread)
+
+    client.disconnect()
+
+
+def test_clean_close_emits_no_error_or_log(
+    mocker: MockFixture, caplog: pytest.LogCaptureFixture
+):
+    """A code-1000 close invokes no Error handler and logs nothing at WARNING or above."""
+    # Given: a code-1000 ConnectionClosed delivered to the read thread (exercises
+    # the `streaming_error is None` short-circuit in _report_connection_closed).
+    # Capture from DEBUG upward so any WARNING/ERROR record would be seen.
+    caplog.set_level(logging.DEBUG)
+    clean_close = ConnectionClosed(rcvd=Close(1000, "session ended"), sent=None)
+    fake_ws = _FakeWebSocket(recv_script=[clean_close], send_raises=clean_close)
+    mocker.patch(
+        "assemblyai.streaming.v3.client.websocket_connect",
+        return_value=fake_ws,
+    )
+    received = []
+
+    def on_error(self_, err):
+        received.append(err)
+
+    client = StreamingClient(
+        StreamingClientOptions(api_key="test", api_host="api.example.com")
+    )
+    client.on(StreamingEvents.Error, on_error)
+
+    # When: the client connects and the read thread processes the clean close
+    _connect_and_wait(client, _default_params())
+
+    # Then: no on_error fires; no WARNING/ERROR-level log mentions the close.
+    # The close path took the `streaming_error is None` short-circuit.
+    assert received == [], f"unexpected on_error calls: {received}"
+    fatal_logs = [rec for rec in caplog.records if rec.levelno >= logging.WARNING]
+    assert fatal_logs == [], (
+        f"clean close should not emit WARNING/ERROR logs, got: "
+        f"{[(r.levelname, r.message) for r in fatal_logs]}"
+    )
+
+    client.disconnect()
+
+
+def test_report_connection_closed_suppresses_dispatch_when_server_error_flag_set(
+    mocker: MockFixture,
+):
+    """_report_connection_closed must not call the Error handler when
+    `_server_error_reported` is already True."""
+    # Given: a client with _server_error_reported pre-set (simulating "read
+    # thread already dispatched the rich server error").
+    fake_ws = _FakeWebSocket(recv_script=[])
+    mocker.patch(
+        "assemblyai.streaming.v3.client.websocket_connect",
+        return_value=fake_ws,
+    )
+    client = StreamingClient(
+        StreamingClientOptions(api_key="test", api_host="api.example.com")
+    )
+    received = []
+    client.on(StreamingEvents.Error, lambda c, e: received.append(e))
+    # connect() is never called in this test, so the fake socket and the flag
+    # are set on the client directly.
+    client._websocket = fake_ws
+    client._server_error_reported = True
+
+    # When: _report_connection_closed runs for the trailing close
+    close_exc = ConnectionClosed(
+        rcvd=Close(4001, "See Error message for details"), sent=None
+    )
+    client._report_connection_closed(close_exc)
+
+    # Then: no on_error dispatch (close was logged but the duplicate suppressed)
+    assert received == [], (
+        f"close path dispatched despite server error already reported: {received}"
+    )
+
+
+def test_disconnect_terminate_sends_terminate_after_stop_set(mocker: MockFixture):
+    """_write_message must still send a queued TerminateSession when
+    `_stop_event` is already set."""
+    # Given: a client with a TerminateSession queued AND _stop_event already set,
+    # simulating the race where disconnect(terminate=True) puts then sets stop
+    # before the write loop's next get().
+    fake_ws = _FakeWebSocket(recv_script=[])
+    mocker.patch(
+        "assemblyai.streaming.v3.client.websocket_connect",
+        return_value=fake_ws,
+    )
+    client = StreamingClient(
+        StreamingClientOptions(api_key="test", api_host="api.example.com")
+    )
+    # connect() is never called here; the fake socket is attached directly.
+    client._websocket = fake_ws
+    client._write_queue.put(TerminateSession())
+    client._stop_event.set()
+
+    # When: the write loop runs (in the test thread)
+    client._write_message()
+
+    # Then: TerminateSession was sent despite stop being set; loop exited cleanly.
+    assert fake_ws.send_call_count == 1, (
+        f"expected exactly 1 send (the TerminateSession), got {fake_ws.send_call_count}"
+    )
+    assert len(fake_ws.sent) == 1
+    assert '"Terminate"' in fake_ws.sent[0]
+
+
+def test_write_thread_close_is_drained_by_read_thread(mocker: MockFixture):
+    """A ConnectionClosed raised by send() reaches the Error handler exactly
+    once, with code 1011."""
+    # Given: recv() always times out (read thread never sees its own close)
+    # but send() raises ConnectionClosed, forcing the _pending_close_error
+    # drain path to be the only way the user can be notified.
+    close_exc = ConnectionClosed(rcvd=Close(1011, "boom"), sent=None)
+    fake_ws = _FakeWebSocket(recv_script=[], send_raises=close_exc)
+    mocker.patch(
+        "assemblyai.streaming.v3.client.websocket_connect",
+        return_value=fake_ws,
+    )
+    received = []
+    client = StreamingClient(
+        StreamingClientOptions(api_key="test", api_host="api.example.com")
+    )
+    client.on(StreamingEvents.Error, lambda c, e: received.append(e))
+
+    # When: connect with seeded audio so the write thread reaches send()
+    _connect_and_wait(
+        client,
+        _default_params(),
+        seed_chunks=[b"\x00" * 320] * 5,
+    )
+
+    # Then: the read thread drained the pending close and dispatched once.
+    assert fake_ws.send_call_count >= 1, "write thread never reached send()"
+    assert len(received) == 1, (
+        f"expected exactly 1 error from drained pending close, got: {received}"
+    )
+    assert received[0].code == 1011
+
+    client.disconnect()