From e926ee4d92c78187abfe28f64f4bf43ca39c85ee Mon Sep 17 00:00:00 2001 From: AssemblyAI Date: Tue, 19 May 2026 14:34:21 -0600 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 795e6acd1a043790a5cdc27079636451694fc219 --- README.md | 226 +++-- assemblyai/__init__.py | 19 +- assemblyai/__version__.py | 2 +- assemblyai/api.py | 23 - assemblyai/streaming/v3/__init__.py | 4 + assemblyai/streaming/v3/_base.py | 267 ++++++ assemblyai/streaming/v3/async_client.py | 512 ++++++++++ assemblyai/streaming/v3/client.py | 250 +---- assemblyai/streaming/v3/models.py | 1 + assemblyai/transcriber.py | 453 +-------- assemblyai/types.py | 164 +--- tests/unit/test_realtime_transcriber.py | 564 ----------- tests/unit/test_streaming.py | 31 + tests/unit/test_streaming_async.py | 1153 +++++++++++++++++++++++ tox.ini | 7 + 15 files changed, 2191 insertions(+), 1485 deletions(-) create mode 100644 assemblyai/streaming/v3/_base.py create mode 100644 assemblyai/streaming/v3/async_client.py delete mode 100644 tests/unit/test_realtime_transcriber.py create mode 100644 tests/unit/test_streaming_async.py diff --git a/README.md b/README.md index 76e8c0b..72154b7 100644 --- a/README.md +++ b/README.md @@ -699,79 +699,187 @@ for result in transcript.auto_highlights.results: ### **Streaming Examples** -[Read more about our streaming service.](https://www.assemblyai.com/docs/streaming/universal-3-pro) +Real-time speech-to-text via WebSocket against the `u3-rt-pro` model. The SDK ships two clients with identical option/event/handler surfaces — `StreamingClient` (threaded) and `AsyncStreamingClient` (asyncio). Pick whichever fits your codebase. + +**Handler contract**: every handler is called as `handler(client, event)`. Plain functions and `async def` functions both work; `AsyncStreamingClient` awaits async handlers inline on the read task, so don't block — use `asyncio.create_task(...)` if you need concurrent work. + +[Read more about the streaming service.](https://www.assemblyai.com/docs/streaming/universal-3-pro) + +
+ Stream a local file (sync) + +```python +import assemblyai as aai +from assemblyai.streaming.v3 import ( + BeginEvent, StreamingClient, StreamingClientOptions, StreamingError, + StreamingEvents, StreamingParameters, TerminationEvent, TurnEvent, +) + +def on_begin(client, event: BeginEvent): + print(f"Session started: {event.id}") + +def on_turn(client, event: TurnEvent): + print(f"{event.transcript} (end_of_turn={event.end_of_turn})") + +def on_terminated(client, event: TerminationEvent): + print(f"Done: {event.audio_duration_seconds}s of audio processed") + +def on_error(client, error: StreamingError): + print(f"Error: {error} (code={error.code})") + +client = StreamingClient(StreamingClientOptions(api_key="")) +client.on(StreamingEvents.Begin, on_begin) +client.on(StreamingEvents.Turn, on_turn) +client.on(StreamingEvents.Termination, on_terminated) +client.on(StreamingEvents.Error, on_error) + +client.connect(StreamingParameters( + sample_rate=16000, speech_model="u3-rt-pro", format_turns=True, +)) +try: + client.stream(aai.extras.stream_file(filepath="audio.wav", sample_rate=16000)) +finally: + client.disconnect(terminate=True) +``` + +
- Stream your microphone in real-time + Stream your microphone (sync) + +`MicrophoneStream` requires PyAudio: ```bash -pip install -U assemblyai +pip install -U "assemblyai[extras]" ``` ```python -import logging -from typing import Type - import assemblyai as aai from assemblyai.streaming.v3 import ( - BeginEvent, - StreamingClient, - StreamingClientOptions, - StreamingError, - StreamingEvents, - StreamingParameters, - TurnEvent, - TerminationEvent, + StreamingClient, StreamingClientOptions, StreamingEvents, StreamingParameters, ) -api_key = "" +def on_turn(client, event): + print(f"{event.transcript} (end_of_turn={event.end_of_turn})") -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +client = StreamingClient(StreamingClientOptions(api_key="")) +client.on(StreamingEvents.Turn, on_turn) +client.connect(StreamingParameters(sample_rate=16000, speech_model="u3-rt-pro")) -def on_begin(self: Type[StreamingClient], event: BeginEvent): - print(f"Session started: {event.id}") +try: + client.stream(aai.extras.MicrophoneStream(sample_rate=16000)) +finally: + client.disconnect(terminate=True) +``` + +
+ +
+ Stream a local file (async) + +`AsyncStreamingClient` mirrors `StreamingClient` with async methods. It's safe to use as an async context manager — `disconnect()` runs on block exit even if user code raises. Don't pass `extras.stream_file` directly (it uses blocking `time.sleep`); pace from an async generator instead. + +```python +import asyncio +from assemblyai.streaming.v3 import ( + AsyncStreamingClient, StreamingClientOptions, StreamingEvents, StreamingParameters, +) + +async def stream_file_async(path: str, sample_rate: int, chunk_duration: float = 0.3): + bytes_per_chunk = int(sample_rate * chunk_duration) * 2 + with open(path, "rb") as f: + while chunk := f.read(bytes_per_chunk): + yield chunk + await asyncio.sleep(chunk_duration) + +async def on_turn(client, event): + print(f"{event.transcript} (end_of_turn={event.end_of_turn})") + +async def main(): + async with AsyncStreamingClient(StreamingClientOptions(api_key="")) as client: + client.on(StreamingEvents.Turn, on_turn) + await client.connect(StreamingParameters( + sample_rate=16000, speech_model="u3-rt-pro", format_turns=True, + )) + await client.stream(stream_file_async("audio.wav", 16000)) + +asyncio.run(main()) +``` + +
-def on_turn(self: Type[StreamingClient], event: TurnEvent): - print(f"{event.transcript} ({event.end_of_turn})") - -def on_terminated(self: Type[StreamingClient], event: TerminationEvent): - print( - f"Session terminated: {event.audio_duration_seconds} seconds of audio processed" - ) - -def on_error(self: Type[StreamingClient], error: StreamingError): - print(f"Error occurred: {error}") - -def main(): - client = StreamingClient( - StreamingClientOptions( - api_key=api_key, - api_host="streaming.assemblyai.com", - ) - ) - - client.on(StreamingEvents.Begin, on_begin) - client.on(StreamingEvents.Turn, on_turn) - client.on(StreamingEvents.Termination, on_terminated) - client.on(StreamingEvents.Error, on_error) - - client.connect( - StreamingParameters( - sample_rate=16000, - speech_model="u3-rt-pro", - ) - ) - - try: - client.stream( - aai.extras.MicrophoneStream(sample_rate=16000) - ) - finally: - client.disconnect(terminate=True) - -if __name__ == "__main__": - main() +
+ Handle errors + +Server-side errors arrive on the `Error` event rather than being raised. The handler receives a `StreamingError` (an `Exception` subclass) with `.code: int | None` — **not** the wire `ErrorEvent` class. + +`StreamingErrorCodes` is a `dict[int, str]` mapping wire codes to human-readable messages. Use `.get(...)` for lookup: + +```python +from assemblyai.streaming.v3 import StreamingErrorCodes + +def on_error(client, error): + message = StreamingErrorCodes.get(error.code, str(error)) + print(f"Streaming error {error.code}: {message}") +``` + +Common codes: `4001` Not Authorized, `4002` Insufficient Funds, `4029` Client sent audio too fast, `4031` Session idle for too long. + +
+ +
+ Change settings mid-session + +`set_params` updates an active session. Typical use: enable turn formatting (punctuation, casing) only on confirmed end-of-turn so partial transcripts stay raw: + +```python +from assemblyai.streaming.v3 import StreamingSessionParameters + +def on_turn(client, event): + if event.end_of_turn and not event.turn_is_formatted: + client.set_params(StreamingSessionParameters(format_turns=True)) +``` + +For voice agents, `force_endpoint()` flushes the current turn — useful when an external signal (UI button, barge-in detection) determines the user has stopped speaking before VAD does: + +```python +client.force_endpoint() # ends the current turn immediately +``` + +
+ +
+ Temporary tokens for browser / edge clients + +Don't ship your API key to browsers. Mint a short-lived token server-side and pass it to the client. + +**Sync server (Flask / WSGI / scripts):** +```python +client = StreamingClient(StreamingClientOptions(api_key="")) +token = client.create_temporary_token(expires_in_seconds=60) +# Send `token` to the browser, which connects with options(token=token). +``` + +**Async server (FastAPI / asyncio):** always wrap in `async with` even though you don't call `connect()` — `create_temporary_token` lazily opens an `httpx.AsyncClient` pool. The context manager closes it on exit; without it you leak a pool every request. + +```python +from fastapi import FastAPI +from assemblyai.streaming.v3 import AsyncStreamingClient, StreamingClientOptions + +app = FastAPI() +MASTER_KEY = "" + +@app.get("/streaming-token") +async def streaming_token(): + async with AsyncStreamingClient(StreamingClientOptions(api_key=MASTER_KEY)) as client: + return {"token": await client.create_temporary_token(expires_in_seconds=60)} +``` + +**Browser / edge client:** pass the token via `StreamingClientOptions(token=...)`: + +```python +client = StreamingClient(StreamingClientOptions(token="")) +client.connect(StreamingParameters(sample_rate=16000, speech_model="u3-rt-pro")) ```
diff --git a/assemblyai/__init__.py b/assemblyai/__init__.py index 77efb71..4662522 100644 --- a/assemblyai/__init__.py +++ b/assemblyai/__init__.py @@ -2,10 +2,9 @@ from .__version__ import __version__ from .client import Client from .lemur import Lemur -from .transcriber import RealtimeTranscriber, Transcriber, Transcript, TranscriptGroup +from .transcriber import Transcriber, Transcript, TranscriptGroup from .types import ( AssemblyAIError, - AudioEncoding, AutohighlightResponse, AutohighlightResult, Chapter, @@ -47,13 +46,6 @@ PIIRedactionPolicy, PIISubstitutionPolicy, RawTranscriptionConfig, - RealtimeError, - RealtimeFinalTranscript, - RealtimePartialTranscript, - RealtimeSessionInformation, - RealtimeSessionOpened, - RealtimeTranscript, - RealtimeWord, RedactPiiAudioOptions, Sentence, Sentiment, @@ -93,7 +85,6 @@ __all__ = [ # types "AssemblyAIError", - "AudioEncoding", "AutohighlightResponse", "AutohighlightResult", "Chapter", @@ -170,14 +161,6 @@ "Word", "WordBoost", "WordSearchMatch", - "RealtimeTranscriber", - "RealtimeError", - "RealtimeFinalTranscript", - "RealtimePartialTranscript", - "RealtimeSessionInformation", - "RealtimeSessionOpened", - "RealtimeTranscript", - "RealtimeWord", # package globals "settings", # packages diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 8261441..79afd54 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.64.2" +__version__ = "0.64.3" diff --git a/assemblyai/api.py b/assemblyai/api.py index b2f666a..7c20645 100644 --- a/assemblyai/api.py +++ b/assemblyai/api.py @@ -9,8 +9,6 @@ ENDPOINT_UPLOAD = "/v2/upload" ENDPOINT_LEMUR_BASE = "/lemur/v3" ENDPOINT_LEMUR = f"{ENDPOINT_LEMUR_BASE}/generate" -ENDPOINT_REALTIME_WEBSOCKET = "/v2/realtime/ws" -ENDPOINT_REALTIME_TOKEN = "/v2/realtime/token" def _get_error_message(response: httpx.Response) -> str: @@ -415,24 +413,3 @@ def lemur_get_response_data( return types.LemurQuestionResponse.parse_obj(json_data) return types.LemurStringResponse.parse_obj(json_data) - - -def create_temporary_token( - client: httpx.Client, - request: types.RealtimeCreateTemporaryTokenRequest, - http_timeout: Optional[float], -) -> str: - response = client.post( - f"{ENDPOINT_REALTIME_TOKEN}", - json=request.dict(exclude_none=True), - timeout=http_timeout, - ) - - if response.status_code != httpx.codes.OK: - raise types.AssemblyAIError( - f"Failed to create temporary token: {_get_error_message(response)}", - response.status_code, - ) - - data = types.RealtimeCreateTemporaryTokenResponse.parse_obj(response.json()) - return data.token diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py index e89ad55..c7d0806 100644 --- a/assemblyai/streaming/v3/__init__.py +++ b/assemblyai/streaming/v3/__init__.py @@ -1,3 +1,4 @@ +from .async_client import AsyncStreamingClient from .client import StreamingClient from .models import ( BeginEvent, @@ -9,6 +10,7 @@ SpeechStartedEvent, StreamingClientOptions, StreamingError, + StreamingErrorCodes, StreamingEvents, StreamingParameters, StreamingPiiPolicy, @@ -21,6 +23,7 @@ ) __all__ = [ + "AsyncStreamingClient", "BeginEvent", "Encoding", "EventMessage", @@ -31,6 +34,7 @@ "StreamingClient", "StreamingClientOptions", "StreamingError", + "StreamingErrorCodes", "StreamingEvents", "StreamingParameters", "StreamingPiiPolicy", diff --git a/assemblyai/streaming/v3/_base.py b/assemblyai/streaming/v3/_base.py new file mode 100644 index 0000000..d46f90a --- /dev/null +++ b/assemblyai/streaming/v3/_base.py @@ -0,0 +1,267 @@ +"""Sync/async-agnostic core for streaming v3 clients. + +Houses the pieces that are *exactly* the same between the threaded +``StreamingClient`` and the asyncio-based ``AsyncStreamingClient``: + +- Wire-format helpers (``_dump_model``, ``_parse_model``, ``_build_uri``, + ``_build_headers``, parameter normalization, user-agent construction). +- Inbound message parsing (``_parse_message`` + ``_parse_event_type``). +- Connection-closed error mapping (``_build_connection_closed_error``). +- The ``_BaseStreamingClient`` base class with shared init state and + the ``on(...)`` handler-registration entrypoint. + +Subclasses must implement the I/O loops (``_read_*`` / ``_write_*``) plus +``connect``, ``disconnect``, ``stream``, ``set_params``, ``force_endpoint``, +and ``create_temporary_token``. Sync subclasses use plain methods; async +subclasses use ``async def``. The sync/async return-type divergence is +why those methods aren't ``@abstractmethod`` on this base. +""" + +import json +import logging +import sys +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from urllib.parse import urlencode + +import websockets +from pydantic import BaseModel + +from assemblyai import __version__ + +from .models import ( + BeginEvent, + ErrorEvent, + EventMessage, + LLMGatewayResponseEvent, + SpeechStartedEvent, + StreamingClientOptions, + StreamingError, + StreamingErrorCodes, + StreamingEvents, + StreamingParameters, + TerminationEvent, + TurnEvent, + WarningEvent, +) + +logger = logging.getLogger(__name__) + + +_M = TypeVar("_M", bound=BaseModel) + + +def _dump_model(model: BaseModel) -> Dict[str, Any]: + if hasattr(model, "model_dump"): + return model.model_dump(exclude_none=True) + return model.dict(exclude_none=True) + + +def _dump_model_json(model: BaseModel) -> str: + if hasattr(model, "model_dump_json"): + return model.model_dump_json(exclude_none=True) + return model.json(exclude_none=True) + + +def _parse_model(model_class: Type[_M], data: Dict[str, Any]) -> _M: + if hasattr(model_class, "model_validate"): + return model_class.model_validate(data) + return model_class.parse_obj(data) + + +def _normalize_min_turn_silence(params_dict: dict) -> dict: + """Collapse `min_end_of_turn_silence_when_confident` into `min_turn_silence` so only + one wire key is ever sent. Emits deprecation warnings.""" + old = params_dict.pop("min_end_of_turn_silence_when_confident", None) + if old is None: + return params_dict + if "min_turn_silence" in params_dict: + logger.warning( + "[Deprecation Warning] Both `min_end_of_turn_silence_when_confident` and " + "`min_turn_silence` are set. Using `min_turn_silence`; " + "`min_end_of_turn_silence_when_confident` is deprecated." + ) + else: + logger.warning( + "[Deprecation Warning] `min_end_of_turn_silence_when_confident` is " + "deprecated and will be removed in a future release. Please use " + "`min_turn_silence` instead." + ) + params_dict["min_turn_silence"] = old + return params_dict + + +def _normalize_voice_focus(params_dict: dict) -> dict: + """Collapse `noise_suppression_model` / `noise_suppression_threshold` into + `voice_focus` / `voice_focus_threshold` so only the new wire keys are sent. + Emits deprecation warnings.""" + for old_key, new_key in ( + ("noise_suppression_model", "voice_focus"), + ("noise_suppression_threshold", "voice_focus_threshold"), + ): + old = params_dict.pop(old_key, None) + if old is None: + continue + if new_key in params_dict: + logger.warning( + f"[Deprecation Warning] Both `{old_key}` and `{new_key}` are set. " + f"Using `{new_key}`; `{old_key}` is deprecated." + ) + else: + logger.warning( + f"[Deprecation Warning] `{old_key}` is deprecated and will be removed " + f"in a future release. Please use `{new_key}` instead." + ) + params_dict[new_key] = old + return params_dict + + +def _user_agent() -> str: + vi = sys.version_info + python_version = f"{vi.major}.{vi.minor}.{vi.micro}" + return ( + f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" + ) + + +def _emit_param_warnings(params: StreamingParameters) -> None: + if params.speech_model == "u3-pro": + logger.warning( + "[Deprecation Warning] The speech model `u3-pro` is deprecated and will be removed in a future release. " + "Please use `u3-rt-pro` instead." + ) + if params.customer_support_audio_capture: + logger.warning( + "`customer_support_audio_capture=True` will record session audio. " + "Only enable this when explicitly coordinating with AssemblyAI support." + ) + + +def _build_uri(host: str, params: StreamingParameters) -> str: + params_dict = _normalize_voice_focus( + _normalize_min_turn_silence(_dump_model(params)) + ) + # JSON-encode list and dict parameters for proper API compatibility (e.g., + # keyterms_prompt, llm_gateway) + for key, value in params_dict.items(): + if isinstance(value, list): + params_dict[key] = json.dumps(value) + elif isinstance(value, dict): + params_dict[key] = json.dumps(value) + + params_encoded = urlencode(params_dict) + + if host.startswith(("ws://", "wss://")): + return f"{host}/v3/ws?{params_encoded}" + return f"wss://{host}/v3/ws?{params_encoded}" + + +def _build_headers(options: StreamingClientOptions) -> Dict[str, Optional[str]]: + # Matches the pre-refactor sync behavior: ``Authorization`` is left as the + # raw value (may be ``None`` when neither ``token`` nor ``api_key`` is set, + # which surfaces the misconfiguration through the websockets/httpx layer). + return { + "Authorization": options.token or options.api_key, + "User-Agent": _user_agent(), + "AssemblyAI-Version": "2025-05-12", + } + + +class _BaseStreamingClient: + """Sync/async-agnostic core for streaming clients. + + Subclasses must implement: ``connect``, ``disconnect``, ``stream``, + ``set_params``, ``force_endpoint``, ``create_temporary_token``, plus + the I/O loops (``_read_*`` / ``_write_*``). Sync subclasses use plain + methods; async subclasses use ``async def`` — the return-type + divergence is why these aren't ``@abstractmethod`` on this base. + """ + + def __init__(self, options: StreamingClientOptions): + self._options = options + self._handlers: Dict[StreamingEvents, List[Callable]] = { + event: [] for event in StreamingEvents.__members__.values() + } + # Dedup flags for one-time error dispatch. ``_report_connection_closed`` + # and ``_report_server_error`` perform their flag check + set + # synchronously (no ``await`` / yield between them) before any + # dispatch, so even when both I/O tasks/threads race to report the + # same close only the first caller executes the dispatch body. + # - Threading: the read thread is the sole dispatcher; the write + # thread stages closes via ``_pending_close_error`` for the read + # thread to drain. + # - Asyncio: either task may call the report function; the sync + # check-and-set inside the function gives the dedup atomicity. + self._connection_closed_reported = False + self._server_error_reported = False + self._websocket: Optional[Any] = None + + def on(self, event: StreamingEvents, handler: Callable) -> None: + """Register a handler for a streaming event. + + ``event`` is a value from ``StreamingEvents`` (``Begin``, ``Turn``, + ``Termination``, ``SpeechStarted``, ``Error``, ``Warning``, + ``LLMGatewayResponse``). ``handler`` is invoked as + ``handler(client, event)``. For ``AsyncStreamingClient``, async + handlers are awaited inline on the read task. Exceptions raised by + handlers are logged and swallowed — they do not terminate the + session. + """ + if event in StreamingEvents.__members__.values() and callable(handler): + self._handlers[event].append(handler) + + @staticmethod + def _parse_event_type(message_type: Optional[Any]) -> Optional[StreamingEvents]: + if not isinstance(message_type, str): + return None + try: + return StreamingEvents[message_type] + except KeyError: + return None + + @classmethod + def _parse_message(cls, data: Dict[str, Any]) -> Optional[EventMessage]: + if "type" in data: + event_type = cls._parse_event_type(data.get("type")) + + if event_type == StreamingEvents.Begin: + return _parse_model(BeginEvent, data) + elif event_type == StreamingEvents.Termination: + return _parse_model(TerminationEvent, data) + elif event_type == StreamingEvents.Turn: + return _parse_model(TurnEvent, data) + elif event_type == StreamingEvents.SpeechStarted: + return _parse_model(SpeechStartedEvent, data) + elif event_type == StreamingEvents.LLMGatewayResponse: + return _parse_model(LLMGatewayResponseEvent, data) + elif event_type == StreamingEvents.Error: + return _parse_model(ErrorEvent, data) + elif event_type == StreamingEvents.Warning: + return _parse_model(WarningEvent, data) + else: + return None + elif "error" in data: + return _parse_model(ErrorEvent, data) + return None + + @staticmethod + def _build_connection_closed_error( + error: Union[ + StreamingError, + ErrorEvent, + websockets.exceptions.ConnectionClosed, + OSError, + ], + ) -> Optional[StreamingError]: + if isinstance(error, StreamingError): + return error + if isinstance(error, ErrorEvent): + return StreamingError(message=error.error, code=error.error_code) + if isinstance(error, websockets.exceptions.ConnectionClosed): + if error.code == 1000: + return None + if error.code is not None and error.code in StreamingErrorCodes: + message = StreamingErrorCodes[error.code] + else: + message = error.reason or f"Connection closed (code={error.code})" + return StreamingError(message=message, code=error.code) + return StreamingError(message=f"Connection failed: {error}") diff --git a/assemblyai/streaming/v3/async_client.py b/assemblyai/streaming/v3/async_client.py new file mode 100644 index 0000000..6804982 --- /dev/null +++ b/assemblyai/streaming/v3/async_client.py @@ -0,0 +1,512 @@ +import asyncio +import collections.abc +import inspect +import json +import logging +from typing import Any, AsyncIterable, Callable, Dict, Iterable, Optional, Union + +import httpx +import websockets +from pydantic import BaseModel + +# Prefer the new asyncio client API (websockets >= 13). Fall back to the legacy +# top-level connect for older versions the SDK still supports per ``setup.py`` +# (``websockets>=11.0``). The two APIs differ only in the header-kwarg name +# (``additional_headers`` vs ``extra_headers``); the ``websocket_connect_async`` +# wrapper below papers that over so tests and callers see one entry point. +try: + from websockets.asyncio.client import connect as _ws_connect + + _WS_HEADER_KW = "additional_headers" +except ImportError: # pragma: no cover - exercised on websockets <13 only + from websockets.client import connect as _ws_connect # type: ignore[no-redef] + + _WS_HEADER_KW = "extra_headers" + +from ._base import ( + _BaseStreamingClient, + _build_headers, + _build_uri, + _dump_model, + _dump_model_json, + _emit_param_warnings, + _normalize_min_turn_silence, + _user_agent, +) +from .models import ( + ErrorEvent, + EventMessage, + ForceEndpoint, + OperationMessage, + StreamingClientOptions, + StreamingError, + StreamingEvents, + StreamingParameters, + StreamingSessionParameters, + TerminateSession, + TerminationEvent, + UpdateConfiguration, + WarningEvent, +) + +logger = logging.getLogger(__name__) + + +def websocket_connect_async( + uri: str, additional_headers: Dict[str, Optional[str]] +) -> Any: + """Open a websocket connection using whichever ``websockets`` API is + available. Returns the underlying ``Connect`` awaitable so callers may + ``await`` it directly (or wrap in ``asyncio.wait_for``). Module-level + indirection so tests can patch a single attribute. + + ``additional_headers`` matches the ``Dict[str, Optional[str]]`` shape + returned by ``_build_headers``; an ``Authorization`` value of ``None`` + (no credentials configured) is forwarded to the underlying websockets + library so the misconfiguration surfaces at the handshake layer. + """ + return _ws_connect(uri, **{_WS_HEADER_KW: additional_headers}) + + +class AsyncStreamingClient(_BaseStreamingClient): + """Asyncio-native counterpart to ``StreamingClient``. + + The public API mirrors the thread-based client one-to-one — same options, + parameters, events, and event-handler registration. Methods that touch the + network are coroutines. Event handlers may be plain callables or + coroutine functions; coroutine handlers are awaited inline by the single + internal read task. Handlers should therefore avoid indefinite blocking, + just as with the sync client. + + Behavioral notes vs. the sync ``StreamingClient``: + + - ``stream`` / ``set_params`` / ``force_endpoint`` raise ``RuntimeError`` + when called before ``connect()`` — silent drop would diverge from the + sync client (which buffers pre-connect data) in a way that's easy to + miss. After the connection has closed, the same calls are silent + no-ops so cleanup paths don't need defensive try/except. + - ``disconnect(terminate=True)`` waits at most 2.0s for the write task to + drain the ``TerminateSession`` frame before forcing teardown. The sync + client joins indefinitely. + - Supports ``async with``: ``disconnect()`` is invoked on block exit so + the websocket / HTTP client are always released even when user code + raises. + """ + + def __init__(self, options: StreamingClientOptions): + super().__init__(options) + + self._client = _AsyncHTTPClient( + api_host=options.api_host, api_key=options.api_key + ) + + # Created lazily in ``connect()`` so they bind to the loop that runs + # ``connect()``, not whatever loop was current at ``__init__`` time + # (matters on Python 3.8/3.9 and avoids "no running event loop" + # DeprecationWarnings on 3.10+ when constructed outside a loop). + self._write_queue: Optional["asyncio.Queue[OperationMessage]"] = None + self._stop_event: Optional[asyncio.Event] = None + self._read_task: Optional[asyncio.Task] = None + self._write_task: Optional[asyncio.Task] = None + + async def connect(self, params: StreamingParameters) -> None: + # Single-use: a client whose connection went down (success or + # handshake failure) sets ``_connection_closed_reported``; reusing + # it would yield a silently dead read/write loop because + # ``_stop_event`` is already set. + already_used = ( + self._websocket is not None + or self._connection_closed_reported + or (self._read_task is not None and not self._read_task.done()) + ) + if already_used: + raise RuntimeError( + "AsyncStreamingClient has already been connected; " + "create a new instance for a new connection." + ) + + self._write_queue = asyncio.Queue() + self._stop_event = asyncio.Event() + + _emit_param_warnings(params) + + uri = _build_uri(self._options.api_host, params) + headers = _build_headers(self._options) + + try: + self._websocket = await asyncio.wait_for( + websocket_connect_async(uri, additional_headers=headers), + timeout=15, + ) + except websockets.exceptions.InvalidStatus as exc: + status_code = getattr(getattr(exc, "response", None), "status_code", None) + await self._report_connection_closed( + StreamingError( + message=f"WebSocket handshake rejected (HTTP {status_code})", + code=status_code, + ) + ) + # Single-use design: a failed handshake terminates the client. Close + # the HTTP client now so users who treat ``on_error`` as the + # terminal signal don't leak the httpx pool. + await self._client.aclose() + return + except ( + websockets.exceptions.InvalidHandshake, + websockets.exceptions.ConnectionClosed, + OSError, + asyncio.TimeoutError, + TimeoutError, + ) as exc: + await self._report_connection_closed(exc) + await self._client.aclose() + return + + self._read_task = asyncio.create_task( + self._read_loop(), name="AsyncStreamingClient._read_loop" + ) + self._write_task = asyncio.create_task( + self._write_loop(), name="AsyncStreamingClient._write_loop" + ) + + logger.debug("Connected to WebSocket server") + + async def disconnect(self, terminate: bool = False) -> None: + if self._stop_event is None: + # Never connected — still close the HTTP client so the pool + # doesn't leak. + await self._client.aclose() + return + + # Enqueue Terminate even when stop is already set: ``_write_loop`` + # bypasses the stop gate for TerminateSession so the frame still + # reaches the server when the write task is alive. + if terminate and self._write_queue is not None: + await self._write_queue.put(TerminateSession()) + # Let the write task drain TerminateSession and exit naturally + # before we set stop / cancel below. ``asyncio.wait`` does not + # cancel the awaited task on timeout, unlike ``wait_for``. + if self._write_task is not None and not self._write_task.done(): + await asyncio.wait({self._write_task}, timeout=2.0) + + self._stop_event.set() + + current = asyncio.current_task() + for task in (self._read_task, self._write_task): + if task is None or task is current or task.done(): + continue + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Streaming task raised during disconnect") + + await self._close_websocket() + await self._client.aclose() + + async def _close_websocket(self) -> None: + if not self._websocket: + return + try: + await self._websocket.close() + except (OSError, websockets.exceptions.WebSocketException) as exc: + logger.debug("Error closing websocket: %s", exc) + + async def stream( + self, + data: Union[bytes, AsyncIterable[bytes], Iterable[bytes]], + ) -> None: + # Loud on misuse (pre-connect), quiet on natural close (post-stop). + # The first guards against silent data loss; the second keeps cleanup + # paths simple. + write_queue, stop_event = self._ensure_connected("stream") + if stop_event.is_set(): + return + + if isinstance(data, bytes): + await write_queue.put(data) + return + + if isinstance(data, collections.abc.AsyncIterable): + async for chunk in data: + if stop_event.is_set(): + return + await write_queue.put(chunk) + return + + for chunk in data: + if stop_event.is_set(): + return + await write_queue.put(chunk) + + async def set_params(self, params: StreamingSessionParameters) -> None: + write_queue, stop_event = self._ensure_connected("set_params") + if stop_event.is_set(): + return + message_dict = _normalize_min_turn_silence(_dump_model(params)) + message = UpdateConfiguration(**message_dict) + await write_queue.put(message) + + async def force_endpoint(self) -> None: + write_queue, stop_event = self._ensure_connected("force_endpoint") + if stop_event.is_set(): + return + await write_queue.put(ForceEndpoint()) + + def _ensure_connected( + self, method: str + ) -> "tuple[asyncio.Queue[OperationMessage], asyncio.Event]": + # Returns the post-connect primitives so callers narrow ``Optional`` + # locally instead of repeating ``is None`` checks at every use site + # (mypy can't propagate narrowing through a separate method call). + if self._write_queue is None or self._stop_event is None: + raise RuntimeError( + f"AsyncStreamingClient is not connected; call connect() before {method}()" + ) + return self._write_queue, self._stop_event + + async def _write_loop(self) -> None: + # ``_write_loop`` is only ``create_task``ed inside ``connect()`` after + # the primitives are initialized. ``if`` (not ``assert``) so it + # survives ``python -O`` if the invariant is ever violated. + if self._write_queue is None or self._stop_event is None: + raise RuntimeError("AsyncStreamingClient internal state not initialized") + while True: + if not self._websocket: + raise ValueError("Not connected to the WebSocket server") + + try: + data = await asyncio.wait_for(self._write_queue.get(), timeout=1) + except asyncio.TimeoutError: + if self._stop_event.is_set(): + return + continue + + # TerminateSession bypasses the stop gate so disconnect(terminate=True) + # can always send it, even when stop is set between put() and the + # write loop's next iteration. + is_terminate = isinstance(data, TerminateSession) + if not is_terminate and self._stop_event.is_set(): + return + + try: + if isinstance(data, bytes): + await self._websocket.send(data) + elif isinstance(data, BaseModel): + await self._websocket.send(_dump_model_json(data)) + else: + raise ValueError(f"Attempted to send invalid message: {type(data)}") + except websockets.exceptions.ConnectionClosed as exc: + # Dispatch the close directly from the write task. The read + # task may short-circuit on ``_stop_event`` at the top of its + # loop (e.g. while a buffered message was processed between + # ``recv()`` calls) and never observe the close in ``recv()``, + # so the write task can't rely on it to dispatch. + # ``_report_connection_closed`` is idempotent — its flag check + # + set is synchronous (no ``await`` between them), so if the + # read task also raises ``ConnectionClosed`` it'll be a no-op. + await self._report_connection_closed(exc) + return + + if is_terminate: + return + + async def _read_loop(self) -> None: + # ``_read_loop`` is only ``create_task``ed inside ``connect()`` after + # ``_stop_event`` is initialized. ``if`` (not ``assert``) so it + # survives ``python -O`` if the invariant is ever violated. + if self._stop_event is None: + raise RuntimeError("AsyncStreamingClient internal state not initialized") + while True: + if not self._websocket: + raise ValueError("Not connected to the WebSocket server") + + if self._stop_event.is_set(): + return + + try: + message_data = await self._websocket.recv() + except websockets.exceptions.ConnectionClosed as exc: + await self._report_connection_closed(exc) + return + + try: + message_json = json.loads(message_data) + except json.JSONDecodeError as exc: + logger.warning(f"Failed to decode message: {exc}") + continue + + message = self._parse_message(message_json) + + if isinstance(message, ErrorEvent): + await self._report_server_error(message) + elif isinstance(message, WarningEvent): + await self._handle_warning(message) + elif message: + await self._handle_message(message) + else: + logger.warning(f"Unsupported event type: {message_json.get('type')}") + + async def _handle_message(self, message: EventMessage) -> None: + # ``_handle_message`` is only reached from ``_read_loop``, which only + # runs after ``connect()`` has initialized ``_stop_event``. + if self._stop_event is None: + raise RuntimeError("AsyncStreamingClient internal state not initialized") + if isinstance(message, TerminationEvent): + self._stop_event.set() + + event_type = StreamingEvents[message.type] + + for handler in self._handlers[event_type]: + await self._invoke_handler(handler, message, event_type) + + async def _handle_warning(self, warning: WarningEvent) -> None: + logger.warning( + "Streaming warning (code=%s): %s", warning.warning_code, warning.warning + ) + for handler in self._handlers[StreamingEvents.Warning]: + await self._invoke_handler(handler, warning, StreamingEvents.Warning) + + async def _report_server_error(self, error: ErrorEvent) -> None: + # Only reachable from ``_read_loop`` (after primitives are initialized). + if self._stop_event is None: + raise RuntimeError("AsyncStreamingClient internal state not initialized") + self._server_error_reported = True + streaming_error = StreamingError(message=error.error, code=error.error_code) + logger.error("Streaming error: %s (code=%s)", error.error, error.error_code) + await self._dispatch_error(streaming_error) + # Tear down locally so a server that sends Error without a trailing + # close frame doesn't leave the read loop blocked in ``recv()`` + # forever. ``_close_websocket`` is idempotent; if the trailing close + # does arrive, ``_report_connection_closed`` will dedup via + # ``_server_error_reported``. + await self._close_websocket() + self._stop_event.set() + + async def _report_connection_closed( + self, + error: Union[ + StreamingError, + ErrorEvent, + websockets.exceptions.ConnectionClosed, + OSError, + ], + ) -> None: + # Callers (``connect()`` failure path, ``_read_loop``, ``_write_loop``) + # all run after ``_stop_event`` is initialized. + if self._stop_event is None: + raise RuntimeError("AsyncStreamingClient internal state not initialized") + if self._connection_closed_reported: + return + self._connection_closed_reported = True + self._stop_event.set() + + streaming_error = self._build_connection_closed_error(error) + + if streaming_error is None: + await self._close_websocket() + return + + if isinstance(error, websockets.exceptions.ConnectionClosed): + reason = error.reason or "no reason given" + logger.error("Connection closed: %s (code=%s)", reason, error.code) + else: + logger.error( + "Connection failed: %s (code=%s)", + streaming_error, + streaming_error.code, + ) + + # If a server Error frame already fired on_error, the close is the + # effect, not a new cause — log it (above) but skip the duplicate + # user-visible error. + if not self._server_error_reported: + await self._dispatch_error(streaming_error) + + await self._close_websocket() + + async def _dispatch_error(self, error: StreamingError) -> None: + for handler in self._handlers[StreamingEvents.Error]: + await self._invoke_handler(handler, error, StreamingEvents.Error) + + async def _invoke_handler( + self, + handler: Callable, + payload: Any, + event_type: StreamingEvents, + ) -> None: + try: + result = handler(self, payload) + if inspect.isawaitable(result): + await result + except Exception: + logger.exception("on_%s handler raised", event_type.name.lower()) + + async def create_temporary_token( + self, + expires_in_seconds: int, + max_session_duration_seconds: Optional[int] = None, + ) -> str: + return await self._client.create_temporary_token( + expires_in_seconds=expires_in_seconds, + max_session_duration_seconds=max_session_duration_seconds, + ) + + async def __aenter__(self) -> "AsyncStreamingClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.disconnect(terminate=exc_type is None) + + +class _AsyncHTTPClient: + def __init__(self, api_host: str, api_key: Optional[str] = None): + # Lazy: don't instantiate httpx.AsyncClient here. Bare construction of + # an AsyncStreamingClient that's never connected (or used only for + # connect() — which doesn't go through the HTTP client) must not + # leak an httpx pool. + self._api_host = api_host + self._api_key = api_key + self._http_client: Optional[httpx.AsyncClient] = None + self._closed = False + + def _get_or_create_client(self) -> httpx.AsyncClient: + if self._http_client is None: + headers = {"User-Agent": f"{httpx._client.USER_AGENT} {_user_agent()}"} + if self._api_key: + headers["Authorization"] = self._api_key + self._http_client = httpx.AsyncClient( + base_url="https://" + self._api_host, + headers=headers, + ) + return self._http_client + + async def create_temporary_token( + self, + expires_in_seconds: int, + max_session_duration_seconds: Optional[int] = None, + ) -> str: + # ``expires_in_seconds`` is required per the type; always forward it + # so passing ``0`` reaches the server (where it can be validated) + # instead of being silently dropped by a falsy check. + params: Dict[str, Any] = {"expires_in_seconds": expires_in_seconds} + + if max_session_duration_seconds is not None: + params["max_session_duration_seconds"] = max_session_duration_seconds + + response = await self._get_or_create_client().get("/v3/token", params=params) + response.raise_for_status() + return response.json()["token"] + + async def aclose(self) -> None: + if self._closed: + return + self._closed = True + if self._http_client is None: + return + try: + await self._http_client.aclose() + except Exception as exc: + logger.debug("Error closing async HTTP client: %s", exc) diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py index cce42d8..65ace7a 100644 --- a/assemblyai/streaming/v3/client.py +++ b/assemblyai/streaming/v3/client.py @@ -1,35 +1,36 @@ import json import logging import queue -import sys import threading -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union -from urllib.parse import urlencode +from typing import Any, Dict, Generator, Iterable, Optional, Union import httpx import websockets from pydantic import BaseModel from websockets.sync.client import connect as websocket_connect -from assemblyai import __version__ - +from ._base import ( + _BaseStreamingClient, + _build_headers, + _build_uri, + _dump_model, + _dump_model_json, + _emit_param_warnings, + _normalize_min_turn_silence, + _user_agent, +) from .models import ( - BeginEvent, ErrorEvent, EventMessage, ForceEndpoint, - LLMGatewayResponseEvent, OperationMessage, - SpeechStartedEvent, StreamingClientOptions, StreamingError, - StreamingErrorCodes, StreamingEvents, StreamingParameters, StreamingSessionParameters, TerminateSession, TerminationEvent, - TurnEvent, UpdateConfiguration, WarningEvent, ) @@ -37,144 +38,34 @@ logger = logging.getLogger(__name__) -def _dump_model(model: BaseModel): - if hasattr(model, "model_dump"): - return model.model_dump(exclude_none=True) - return model.dict(exclude_none=True) - - -def _parse_model(model_class, data): - if hasattr(model_class, "model_validate"): - return model_class.model_validate(data) - return model_class.parse_obj(data) - - -def _normalize_min_turn_silence(params_dict: dict) -> dict: - """Collapse `min_end_of_turn_silence_when_confident` into `min_turn_silence` so only - one wire key is ever sent. Emits deprecation warnings.""" - old = params_dict.pop("min_end_of_turn_silence_when_confident", None) - if old is None: - return params_dict - if "min_turn_silence" in params_dict: - logger.warning( - "[Deprecation Warning] Both `min_end_of_turn_silence_when_confident` and " - "`min_turn_silence` are set. Using `min_turn_silence`; " - "`min_end_of_turn_silence_when_confident` is deprecated." - ) - else: - logger.warning( - "[Deprecation Warning] `min_end_of_turn_silence_when_confident` is " - "deprecated and will be removed in a future release. Please use " - "`min_turn_silence` instead." - ) - params_dict["min_turn_silence"] = old - return params_dict - - -def _normalize_voice_focus(params_dict: dict) -> dict: - """Collapse `noise_suppression_model` / `noise_suppression_threshold` into - `voice_focus` / `voice_focus_threshold` so only the new wire keys are sent. - Emits deprecation warnings.""" - for old_key, new_key in ( - ("noise_suppression_model", "voice_focus"), - ("noise_suppression_threshold", "voice_focus_threshold"), - ): - old = params_dict.pop(old_key, None) - if old is None: - continue - if new_key in params_dict: - logger.warning( - f"[Deprecation Warning] Both `{old_key}` and `{new_key}` are set. " - f"Using `{new_key}`; `{old_key}` is deprecated." - ) - else: - logger.warning( - f"[Deprecation Warning] `{old_key}` is deprecated and will be removed " - f"in a future release. Please use `{new_key}` instead." - ) - params_dict[new_key] = old - return params_dict - - -def _dump_model_json(model: BaseModel): - if hasattr(model, "model_dump_json"): - return model.model_dump_json(exclude_none=True) - return model.json(exclude_none=True) - - -def _user_agent() -> str: - vi = sys.version_info - python_version = f"{vi.major}.{vi.minor}.{vi.micro}" - return ( - f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" - ) - - -class StreamingClient: +class StreamingClient(_BaseStreamingClient): def __init__(self, options: StreamingClientOptions): - self._options = options + super().__init__(options) self._client = _HTTPClient(api_host=options.api_host, api_key=options.api_key) - self._handlers: Dict[StreamingEvents, List[Callable]] = {} - - for event in StreamingEvents.__members__.values(): - self._handlers[event] = [] - self._write_queue: queue.Queue[OperationMessage] = queue.Queue() self._write_thread = threading.Thread(target=self._write_message) self._read_thread = threading.Thread(target=self._read_message) self._stop_event = threading.Event() - # Both flags are read and set only on the read thread (or on the main - # thread before workers start, for handshake errors). Plain bools are - # sufficient — no cross-thread synchronization is needed. - self._connection_closed_reported = False - self._server_error_reported = False # Deliberate single-slot shared-memory handoff: the write thread parks # a ConnectionClosed here and the read thread drains it. Synchronization # is provided by `_stop_event.set()` (write side) + `recv(timeout=1)` # (read side), which together give a happens-before within ~1s. self._pending_close_error: Optional[Exception] = None - self._websocket = None def connect(self, params: StreamingParameters) -> None: - if params.speech_model == "u3-pro": - logger.warning( - "[Deprecation Warning] The speech model `u3-pro` is deprecated and will be removed in a future release. " - "Please use `u3-rt-pro` instead." - ) + """Open the WebSocket session and start the read/write threads. - if params.customer_support_audio_capture: - logger.warning( - "`customer_support_audio_capture=True` will record session audio. " - "Only enable this when explicitly coordinating with AssemblyAI support." - ) + Blocks until the handshake completes. If the server rejects the + handshake (auth error, etc.) ``Error`` is dispatched to any + ``on(StreamingEvents.Error, ...)`` handler rather than raised, so + registration order matters: call ``on()`` before ``connect()``. + """ + _emit_param_warnings(params) - params_dict = _normalize_voice_focus( - _normalize_min_turn_silence(_dump_model(params)) - ) - - # JSON-encode list and dict parameters for proper API compatibility (e.g., keyterms_prompt, llm_gateway) - for key, value in params_dict.items(): - if isinstance(value, list): - params_dict[key] = json.dumps(value) - elif isinstance(value, dict): - params_dict[key] = json.dumps(value) - - params_encoded = urlencode(params_dict) - - host = self._options.api_host - if host.startswith(("ws://", "wss://")): - uri = f"{host}/v3/ws?{params_encoded}" - else: - uri = f"wss://{host}/v3/ws?{params_encoded}" - headers = { - "Authorization": self._options.token - if self._options.token - else self._options.api_key, - "User-Agent": _user_agent(), - "AssemblyAI-Version": "2025-05-12", - } + uri = _build_uri(self._options.api_host, params) + headers = _build_headers(self._options) try: self._websocket = websocket_connect( @@ -206,6 +97,14 @@ def connect(self, params: StreamingParameters) -> None: logger.debug("Connected to WebSocket server") def disconnect(self, terminate: bool = False) -> None: + """Stop the read/write threads and close the WebSocket. + + Pass ``terminate=True`` for a graceful close — the client sends a + ``TerminateSession`` frame and waits for the server's + ``TerminationEvent`` (which reports total audio duration). Without + ``terminate=True`` the WebSocket is closed without notifying the + server. + """ # Enqueue Terminate even when stop is already set: `_write_message` # bypasses the stop gate for TerminateSession so the frame still # reaches the server when the write thread is alive. @@ -236,6 +135,13 @@ def _close_websocket(self) -> None: def stream( self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]] ) -> None: + """Send audio bytes to the server. + + Accepts a raw ``bytes`` buffer or any (sync) iterable of ``bytes``. + Returns once all chunks are enqueued — the write thread does the + actual sending. After ``disconnect()`` (or a connection drop) this + becomes a silent no-op. + """ if self._stop_event.is_set(): return @@ -257,10 +163,6 @@ def force_endpoint(self): message = ForceEndpoint() self._write_queue.put(message) - def on(self, event: StreamingEvents, handler: Callable) -> None: - if event in StreamingEvents.__members__.values() and callable(handler): - self._handlers[event].append(handler) - def _write_message(self) -> None: while True: if not self._websocket: @@ -335,7 +237,7 @@ def _read_message(self) -> None: elif message: self._handle_message(message) else: - logger.warning(f"Unsupported event type: {message_json['type']}") + logger.warning(f"Unsupported event type: {message_json.get('type')}") def _handle_message(self, message: EventMessage) -> None: if isinstance(message, TerminationEvent): @@ -349,43 +251,6 @@ def _handle_message(self, message: EventMessage) -> None: except Exception: logger.exception("on_%s handler raised", event_type.name.lower()) - def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]: - if "type" in data: - message_type = data.get("type") - - event_type = self._parse_event_type(message_type) - - if event_type == StreamingEvents.Begin: - return _parse_model(BeginEvent, data) - elif event_type == StreamingEvents.Termination: - return _parse_model(TerminationEvent, data) - elif event_type == StreamingEvents.Turn: - return _parse_model(TurnEvent, data) - elif event_type == StreamingEvents.SpeechStarted: - return _parse_model(SpeechStartedEvent, data) - elif event_type == StreamingEvents.LLMGatewayResponse: - return _parse_model(LLMGatewayResponseEvent, data) - elif event_type == StreamingEvents.Error: - return _parse_model(ErrorEvent, data) - elif event_type == StreamingEvents.Warning: - return _parse_model(WarningEvent, data) - else: - return None - elif "error" in data: - return _parse_model(ErrorEvent, data) - - return None - - @staticmethod - def _parse_event_type(message_type: Optional[Any]) -> Optional[StreamingEvents]: - if not isinstance(message_type, str): - return None - - try: - return StreamingEvents[message_type] - except KeyError: - return None - def _handle_warning(self, warning: WarningEvent): logger.warning( "Streaming warning (code=%s): %s", warning.warning_code, warning.warning @@ -460,29 +325,6 @@ def _dispatch_error(self, error: StreamingError) -> None: except Exception: logger.exception("on_error handler raised") - @staticmethod - def _build_connection_closed_error( - error: Union[ - StreamingError, - ErrorEvent, - websockets.exceptions.ConnectionClosed, - OSError, - ], - ) -> Optional[StreamingError]: - if isinstance(error, StreamingError): - return error - if isinstance(error, ErrorEvent): - return StreamingError(message=error.error, code=error.error_code) - if isinstance(error, websockets.exceptions.ConnectionClosed): - if error.code == 1000: - return None - if error.code is not None and error.code in StreamingErrorCodes: - message = StreamingErrorCodes[error.code] - else: - message = error.reason or f"Connection closed (code={error.code})" - return StreamingError(message=message, code=error.code) - return StreamingError(message=f"Connection failed: {error}") - def create_temporary_token( self, expires_in_seconds: int, @@ -496,11 +338,7 @@ def create_temporary_token( class _HTTPClient: def __init__(self, api_host: str, api_key: Optional[str] = None): - vi = sys.version_info - python_version = f"{vi.major}.{vi.minor}.{vi.micro}" - user_agent = f"{httpx._client.USER_AGENT} AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" - - headers = {"User-Agent": user_agent} + headers = {"User-Agent": f"{httpx._client.USER_AGENT} {_user_agent()}"} if api_key: headers["Authorization"] = api_key @@ -515,12 +353,12 @@ def create_temporary_token( expires_in_seconds: int, max_session_duration_seconds: Optional[int] = None, ) -> str: - params: Dict[str, Any] = {} - - if expires_in_seconds: - params["expires_in_seconds"] = expires_in_seconds + # ``expires_in_seconds`` is required per the type; always forward it + # so passing ``0`` reaches the server (where it can be validated) + # instead of being silently dropped by a falsy check. + params: Dict[str, Any] = {"expires_in_seconds": expires_in_seconds} - if max_session_duration_seconds: + if max_session_duration_seconds is not None: params["max_session_duration_seconds"] = max_session_duration_seconds response = self._http_client.get( diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index 5b35a1c..c7cd5f4 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -106,6 +106,7 @@ class StreamingSessionParameters(BaseModel): filter_profanity: Optional[bool] = None prompt: Optional[str] = None interruption_delay: Optional[int] = None + turn_left_pad_ms: Optional[int] = None class Encoding(str, Enum): diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index ef94bac..7bf40ca 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -2,18 +2,11 @@ import concurrent.futures import functools -import json import os -import queue -import threading import time from typing import ( - Any, BinaryIO, - Callable, Dict, - Generator, - Iterable, Iterator, List, Optional, @@ -21,13 +14,10 @@ Tuple, Union, ) -from urllib.parse import urlencode, urlparse +from urllib.parse import urlparse import httpx -import websockets -import websockets.exceptions from typing_extensions import Self -from websockets.sync.client import connect as websocket_connect from . import api, lemur, types from . import client as _client @@ -1281,444 +1271,3 @@ def list_transcripts_async( Returns: A page with a list of transcripts along with page details. """ return self._executor.submit(self._impl.list_transcripts, params=params) - - -class _RealtimeTranscriberImpl: - def __init__( - self, - *, - on_data: Callable[[types.RealtimeTranscript], None], - on_error: Callable[[types.RealtimeError], None], - on_open: Optional[Callable[[types.RealtimeSessionOpened], None]], - on_close: Optional[Callable[[], None]], - sample_rate: int, - word_boost: List[str], - encoding: Optional[types.AudioEncoding] = None, - token: Optional[str] = None, - client: _client.Client, - end_utterance_silence_threshold: Optional[int], - disable_partial_transcripts: Optional[bool], - on_extra_session_information: Optional[ - Callable[[types.RealtimeSessionInformation], None] - ] = None, - ) -> None: - self._client = client - self._websocket: Optional[websockets.sync.client.ClientConnection] = None - - self._on_open = on_open - self._on_data = on_data - self._on_error = on_error - self._on_close = on_close - self._sample_rate = sample_rate - self._word_boost = word_boost - self._encoding = encoding - self._token = token - self._end_utterance_silence_threshold = end_utterance_silence_threshold - self._disable_partial_transcripts = disable_partial_transcripts - self._on_extra_session_information = on_extra_session_information - - self._write_queue: queue.Queue[Union[bytes, Dict]] = queue.Queue() - self._write_thread = threading.Thread(target=self._write) - self._read_thread = threading.Thread(target=self._read) - self._stop_event = threading.Event() - - def connect( - self, - timeout: Optional[float], - ) -> None: - """ - Connects to the real-time service. - - Args: - `timeout`: The maximum time to wait for the connection to be established. - """ - - params: Dict[str, Any] = { - "sample_rate": self._sample_rate, - } - if self._word_boost: - params["word_boost"] = json.dumps(self._word_boost) - if self._encoding: - params["encoding"] = self._encoding.value - if self._token: - params["token"] = self._token - if self._disable_partial_transcripts: - params["disable_partial_transcripts"] = self._disable_partial_transcripts - if self._on_extra_session_information: - params["enable_extra_session_information"] = True - - websocket_base_url = self._client.settings.base_url.replace("https", "wss") - - additional_headers = None - if self._token is None: - additional_headers = {"Authorization": f"{self._client.settings.api_key}"} - - try: - self._websocket = websocket_connect( - f"{websocket_base_url}{api.ENDPOINT_REALTIME_WEBSOCKET}?{urlencode(params)}", - additional_headers=additional_headers, - open_timeout=timeout, - ) - except Exception as exc: - return self._on_error( - types.RealtimeError( - f"Could not connect to the real-time service: {exc}" - ) - ) - - self._read_thread.start() - self._write_thread.start() - - if self._end_utterance_silence_threshold is not None: - self.configure_end_utterance_silence_threshold( - self._end_utterance_silence_threshold - ) - - def stream(self, data: bytes) -> None: - """ - Streams audio data to the real-time service by putting it into a queue. - """ - - self._write_queue.put(data) - - def configure_end_utterance_silence_threshold( - self, threshold_milliseconds: int - ) -> None: - """ - Configures the end of utterance silence threshold. - Can be called multiple times during a session at any point after the session starts. - - Args: - `threshold_milliseconds`: The threshold in milliseconds. - """ - - self._write_queue.put( - _RealtimeEndUtteranceSilenceThreshold(threshold_milliseconds).as_dict() - ) - - def force_end_utterance(self) -> None: - """ - Forces the end of the current utterance. - """ - - self._write_queue.put(_RealtimeForceEndUtterance().as_dict()) - - def close(self, terminate: bool = False) -> None: - """ - Closes the connection to the real-time service gracefully. - """ - if terminate and not self._stop_event.is_set(): - self._write_queue.put({"terminate_session": True}) - - try: - self._read_thread.join() - self._write_thread.join() - if self._websocket: - self._websocket.close() - except Exception: - pass - - if self._on_close: - self._on_close() - - def _read(self) -> None: - """ - Reads messages from the real-time service. - - Must run in a separate thread to avoid blocking the main thread. - """ - - while not self._stop_event.is_set(): - if not self._websocket: - raise ValueError("Websocket is None") - - try: - recv_message = self._websocket.recv(timeout=1) - except TimeoutError: - continue - except websockets.exceptions.ConnectionClosed as exc: - return self._handle_error(exc) - - try: - message = json.loads(recv_message) - except json.JSONDecodeError as exc: - self._on_error( - types.RealtimeError( - f"Could not decode message: {exc}", - ) - ) - continue - - self._handle_message(message) - - def _write(self) -> None: - """ - Writes messages to the real-time service. - - Must run in a separate thread to avoid blocking the main thread. - """ - - while not self._stop_event.is_set(): - try: - data = self._write_queue.get(timeout=1) - except queue.Empty: - continue - - try: - if not self._websocket: - raise ValueError("websocket is None") - elif isinstance(data, dict): - self._websocket.send(json.dumps(data)) - elif isinstance(data, bytes): - self._websocket.send(data) - else: - raise ValueError("unsupported message type") - except websockets.exceptions.ConnectionClosed as exc: - return self._handle_error(exc) - - def _handle_message( - self, - message: Dict[str, Any], - ) -> None: - """ - Handles a message received from the real-time service by calling the appropriate - callback. - - Args: - `message`: The message to handle. - """ - if "message_type" in message: - if message["message_type"] == types.RealtimeMessageTypes.partial_transcript: - self._on_data(types.RealtimePartialTranscript(**message)) - elif message["message_type"] == types.RealtimeMessageTypes.final_transcript: - self._on_data(types.RealtimeFinalTranscript(**message)) - elif ( - message["message_type"] == types.RealtimeMessageTypes.session_begins - and self._on_open - ): - self._on_open(types.RealtimeSessionOpened(**message)) - elif ( - message["message_type"] == types.RealtimeMessageTypes.session_terminated - ): - self._stop_event.set() - elif ( - message["message_type"] - == types.RealtimeMessageTypes.session_information - ): - if self._on_extra_session_information is not None: - self._on_extra_session_information( - types.RealtimeSessionInformation(**message) - ) - elif "error" in message: - self._on_error(types.RealtimeError(message["error"])) - - def _handle_error(self, error: websockets.exceptions.ConnectionClosed) -> None: - """ - Handles a WebSocket error by calling the appropriate callback. - - See a list of errors here: - - - https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number - - https://www.assemblyai.com/docs/Guides/real-time_streaming_transcription#closing-and-status-codes - """ - if ( - error.code >= 4000 - and error.code <= 4999 - and error.code in types.RealtimeErrorMapping - ): - error_message = types.RealtimeErrorMapping[error.code] - else: - error_message = error.reason - - if error.code != 1000: - self._on_error(types.RealtimeError(error_message, error.code)) - - self.close() - - @classmethod - def create_temporary_token( - cls, - expires_in: int, - timeout: Optional[float] = None, - ) -> str: - """ - Request a temporary authentication token. - - Args: - expires_in: The amount of time until the token expires in seconds. - timeout: The timeout in seconds to wait for a response. - A `timeout` of `None` means no timeout. - - Returns: The temporary authentication token. - """ - - return api.create_temporary_token( - client=_client.Client.get_default().http_client, - request=types.RealtimeCreateTemporaryTokenRequest( - expires_in=expires_in, - ), - http_timeout=timeout, - ) - - -class _RealtimeForceEndUtterance: - def as_dict(self) -> Dict[str, bool]: - return { - "force_end_utterance": True, - } - - -class _RealtimeEndUtteranceSilenceThreshold: - def __init__(self, threshold_milliseconds: int) -> None: - self._value = threshold_milliseconds - - @property - def value(self) -> int: - return self._value - - def as_dict(self) -> Dict[str, int]: - return {"end_utterance_silence_threshold": self._value} - - -class RealtimeTranscriber: - def __init__( - self, - *, - on_data: Callable[[types.RealtimeTranscript], None], - on_error: Callable[[types.RealtimeError], None], - on_open: Optional[Callable[[types.RealtimeSessionOpened], None]] = None, - on_close: Optional[Callable[[], None]] = None, - sample_rate: int, - word_boost: List[str] = [], - encoding: Optional[types.AudioEncoding] = None, - token: Optional[str] = None, - client: Optional[_client.Client] = None, - end_utterance_silence_threshold: Optional[int] = None, - disable_partial_transcripts: Optional[bool] = None, - on_extra_session_information: Optional[ - Callable[[types.RealtimeSessionInformation], None] - ] = None, - ) -> None: - """ - Creates a new real-time transcriber. - - Args: - `on_data`: The callback to call when a new transcript is received. - `on_error`: The callback to call when an error occurs. - `on_open`: (Optional) The callback to call when the connection to the real-time service opens. - `on_close`: (Optional) The callback to call when the connection to the real-time service closes. - `sample_rate`: The sample rate of the audio data. - `word_boost`: (Optional) A list of words to boost transcription probability for. - `encoding`: (Optional) The encoding of the audio data. - `token`: (Optional) A temporary authentication token. - `client`: (Optional) The client to use for the real-time service. - `end_utterance_silence_threshold`: (Optional) The end utterance silence threshold in milliseconds. - `disable_partial_transcripts`: (Optional) If set to `True`, only final transcripts will be received. - `on_extra_session_information`: (Optional) The callback to call when a `SessionInformation` message is received. - If this callback is set, the parameter `enable_extra_session_information` is sent to the API, and the client - receives a `SessionInformation` message right before receiving the session termination message. - """ - - self._client = client or _client.Client.get_default( - api_key_required=token is None - ) - - self._impl = _RealtimeTranscriberImpl( - on_open=on_open, - on_data=on_data, - on_error=on_error, - on_close=on_close, - sample_rate=sample_rate, - word_boost=word_boost, - encoding=encoding, - token=token, - client=self._client, - end_utterance_silence_threshold=end_utterance_silence_threshold, - disable_partial_transcripts=disable_partial_transcripts, - on_extra_session_information=on_extra_session_information, - ) - - def connect( - self, - timeout: Optional[float] = 10.0, - ) -> None: - """ - Connects to the real-time service. - - Args: - `timeout`: The timeout in seconds to wait for the connection to be established. - A `timeout` of `None` means no timeout. - """ - - self._impl.connect(timeout=timeout) - - def stream( - self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]] - ) -> None: - """ - Streams raw audio data to the real-time service. - - Args: - `data`: Raw audio data in `bytes` or a generator/iterable of `bytes`. - - Note: Make sure that `data` matches the `sample_rate` that was given in the constructor. - """ - if isinstance(data, bytes): - self._impl.stream(data) - return - - for chunk in data: - self._impl.stream(chunk) - - def configure_end_utterance_silence_threshold( - self, threshold_milliseconds: int - ) -> None: - """ - Configures the silence duration threshold used to detect the end of an utterance. - In practice, it's used to tune how the transcriptions are split into final transcripts. - Can be called multiple times during a session at any point after the session starts. - - Args: - `threshold_milliseconds`: The threshold in milliseconds. - """ - self._impl.configure_end_utterance_silence_threshold(threshold_milliseconds) - - def force_end_utterance(self) -> None: - """ - Forces the end of the current utterance. - After calling this method, the server will end the current utterance and return a final transcript. - """ - self._impl.force_end_utterance() - - def close(self) -> None: - """ - Closes the connection to the real-time service. - """ - - self._impl.close(terminate=True) - - @classmethod - def create_temporary_token( - cls, - expires_in: int, - timeout: Optional[float] = None, - ) -> str: - """ - Request a temporary authentication token. - - Example: - To create a token, you can simply do: - ``` - token = aai.RealtimeTranscriber.create_temporary_token(expires_in=360000) - ``` - - Args: - expires_in: The amount of time until the token expires in seconds. - timeout: The timeout in seconds to wait for a response. - A `timeout` of `None` means no timeout. - - Returns: The temporary authentication token. - """ - return _RealtimeTranscriberImpl.create_temporary_token( - expires_in=expires_in, timeout=timeout - ) diff --git a/assemblyai/types.py b/assemblyai/types.py index f97d176..6f29e36 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -21,7 +21,7 @@ try: # pydantic v2 import - from pydantic import UUID4, BaseModel, ConfigDict, Field, field_validator + from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict pydantic_v2 = True @@ -34,7 +34,7 @@ ) from None # pydantic v1 import (fallback for Python < 3.14) - from pydantic.v1 import UUID4, BaseModel, BaseSettings, ConfigDict, Field, validator + from pydantic.v1 import BaseModel, BaseSettings, ConfigDict, Field, validator pydantic_v2 = False @@ -2951,163 +2951,3 @@ class LemurPurgeResponse(BaseModel): deleted: bool "The result of the LeMUR purge request" - - -class RealtimeMessageTypes(str, Enum): - """ - The type of message received from the real-time API - """ - - partial_transcript = "PartialTranscript" - final_transcript = "FinalTranscript" - session_begins = "SessionBegins" - session_terminated = "SessionTerminated" - session_information = "SessionInformation" - - -class AudioEncoding(str, Enum): - """ - The encoding of the audio data - """ - - pcm_s16le = "pcm_s16le" - pcm_mulaw = "pcm_mulaw" - - -class RealtimeCreateTemporaryTokenRequest(BaseModel): - expires_in: int - "The amount of time until the token expires in seconds" - - -class RealtimeCreateTemporaryTokenResponse(BaseModel): - token: str - "The temporary authentication token for real-time transcription" - - -class RealtimeSessionOpened(BaseModel): - """ - Once a real-time session is opened, the client will receive this message - """ - - message_type: RealtimeMessageTypes = RealtimeMessageTypes.session_begins - - session_id: UUID4 - "Unique identifier for the established session." - - expires_at: datetime - "Timestamp when this session will expire." - - -class RealtimeWord(BaseModel): - """ - A word in a real-time transcript - """ - - start: int - "Start time of word relative to session start, in milliseconds" - - end: int - "End time of word relative to session start, in milliseconds" - - confidence: float - "The confidence score of the word, between 0 and 1" - - text: str - "The word itself" - - -class RealtimeTranscript(BaseModel): - """ - Base class for real-time transcript messages. - """ - - message_type: RealtimeMessageTypes - "Describes the type of message" - - audio_start: int - "Start time of audio sample relative to session start, in milliseconds" - - audio_end: int - "End time of audio sample relative to session start, in milliseconds" - - confidence: float - "The confidence score of the entire transcription, between 0 and 1" - - text: str - "The transcript for your audio" - - words: List[RealtimeWord] - """ - An array of objects, with the information for each word in the transcription text. - Will include the `start`/`end` time (in milliseconds) of the word, the `confidence` score of the word, - and the `text` (i.e. the word itself) - """ - - created: datetime - "Timestamp when this message was created" - - -class RealtimePartialTranscript(RealtimeTranscript): - """ - As you send audio data to the service, the service will immediately start responding with partial transcripts. - """ - - message_type: RealtimeMessageTypes = RealtimeMessageTypes.partial_transcript - - -class RealtimeFinalTranscript(RealtimeTranscript): - """ - After you've received your partial results, our model will continue to analyze incoming audio and, - when it detects the end of an "utterance" (usually a pause in speech), it will finalize the results - sent to you so far with higher accuracy, as well as add punctuation and casing to the transcription text. - """ - - message_type: RealtimeMessageTypes = RealtimeMessageTypes.final_transcript - - punctuated: bool - "Whether the transcript has been punctuated and cased" - - text_formatted: bool - "Whether the transcript has been formatted (e.g. Dollar -> $)" - - -class RealtimeSessionInformation(BaseModel): - """ - If `on_extra_session_information` is set, the client receives this message - right before receiving the session termination message. - """ - - message_type: RealtimeMessageTypes = RealtimeMessageTypes.session_information - - audio_duration_seconds: float - "The duration of the audio in seconds" - - -class RealtimeError(AssemblyAIError): - """ - Real-time error message - """ - - -RealtimeErrorMapping = { - 4000: "Sample rate must be a positive integer", - 4001: "Not Authorized", - 4002: "Insufficient Funds", - 4003: """This feature is paid-only and requires you to add a credit card. - Please visit https://app.assemblyai.com/ to add a credit card to your account""", - 4004: "Session Not Found", - 4008: "Session Expired", - 4010: "Session Previously Closed", - 4029: "Client sent audio too fast", - 4030: "Session is handled by another websocket", - 4031: "Session idle for too long", - 4032: "Audio duration is too short", - 4033: "Audio duration is too long", - 4034: "Audio too small to transcode", - 4100: "Endpoint received invalid JSON", - 4101: "Endpoint received a message with an invalid schema", - 4102: "This account has exceeded the number of allowed streams", - 4103: "The session has been reconnected. This websocket is no longer valid.", - 4104: "Could not parse word boost parameter", - 1013: "Temporary server condition forced blocking client's request", -} diff --git a/tests/unit/test_realtime_transcriber.py b/tests/unit/test_realtime_transcriber.py deleted file mode 100644 index c8d978a..0000000 --- a/tests/unit/test_realtime_transcriber.py +++ /dev/null @@ -1,564 +0,0 @@ -import datetime -import json -import uuid -from unittest.mock import MagicMock -from urllib.parse import urlencode - -import httpx -import pytest -import websockets.exceptions -from faker import Faker -from pytest_httpx import HTTPXMock -from pytest_mock import MockFixture - -import assemblyai as aai -from assemblyai.api import ENDPOINT_REALTIME_TOKEN - -aai.settings.api_key = "test" - - -def _disable_rw_threads(mocker: MockFixture): - """ - Disable the read/write threads for the websocket - """ - - mocker.patch("threading.Thread.start", return_value=None) - - -@pytest.mark.parametrize( - "encoding,token,expected_header", - [ - (None, None, {"Authorization": "test"}), - (aai.AudioEncoding.pcm_s16le, None, {"Authorization": "test"}), - (aai.AudioEncoding.pcm_mulaw, None, {"Authorization": "test"}), - (None, "12345678", None), - (aai.AudioEncoding.pcm_s16le, "12345678", None), - ], -) -def test_realtime_connect_has_parameters( - encoding, token, expected_header, mocker: MockFixture -): - """ - Test that the connect method has the correct parameters set - """ - aai.settings.base_url = "https://api.assemblyai.com" - - actual_url = None - actual_additional_headers = None - actual_open_timeout = None - - def mocked_websocket_connect( - url: str, additional_headers: dict, open_timeout: float - ): - nonlocal actual_url, actual_additional_headers, actual_open_timeout - actual_url = url - actual_additional_headers = additional_headers - actual_open_timeout = open_timeout - - mocker.patch( - "assemblyai.transcriber.websocket_connect", - new=mocked_websocket_connect, - ) - _disable_rw_threads(mocker) - - transcriber = aai.RealtimeTranscriber( - on_data=lambda: None, - on_error=lambda error: print(error), - sample_rate=44_100, - word_boost=["AssemblyAI"], - encoding=encoding, - token=token, - ) - - transcriber.connect(timeout=15.0) - - params = dict(sample_rate=44100, word_boost=json.dumps(["AssemblyAI"])) - if encoding: - params["encoding"] = encoding.value - if token: - params["token"] = token - - assert actual_url == f"wss://api.assemblyai.com/v2/realtime/ws?{urlencode(params)}" - assert actual_additional_headers == expected_header - assert actual_open_timeout == 15.0 - - -def test_realtime_connect_succeeds(mocker: MockFixture): - """ - Tests that the `RealtimeTranscriber` successfully connects to the `real-time` service. - """ - on_error_called = False - - def on_error(error: aai.RealtimeError): - nonlocal on_error_called - on_error_called = True - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, - on_error=on_error, - sample_rate=44_100, - ) - - mocker.patch( - "assemblyai.transcriber.websocket_connect", - return_value=MagicMock(), - ) - - # mock the read/write threads - _disable_rw_threads(mocker) - - # should pass - transcriber.connect() - - # no errors should be called - assert not on_error_called - - -def test_realtime_token_connect_succeeds(mocker: MockFixture): - """ - Tests that the `RealtimeTranscriber` successfully connects - to the `real-time` service when a token is used. - """ - on_error_called = False - - # reset the API key - mocker.patch("assemblyai.settings.api_key", new=None) - - def on_error(error: aai.RealtimeError): - nonlocal on_error_called - on_error_called = True - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, on_error=on_error, sample_rate=44_100, token="12345" - ) - - mocker.patch( - "assemblyai.transcriber.websocket_connect", - return_value=MagicMock(), - ) - - # mock the read/write threads - _disable_rw_threads(mocker) - - # should pass - transcriber.connect() - - # no errors should be called - assert not on_error_called - - -def test_realtime_connect_fails(mocker: MockFixture): - """ - Tests that the `RealtimeTranscriber` fails to connect to the `real-time` service. - """ - - on_error_called = False - - def on_error(error: aai.RealtimeError): - nonlocal on_error_called - on_error_called = True - - assert isinstance(error, aai.RealtimeError) - assert "connection failed" in str(error) - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, - on_error=on_error, - sample_rate=44_100, - ) - mocker.patch( - "assemblyai.transcriber.websocket_connect", - side_effect=Exception("connection failed"), - ) - - transcriber.connect() - - assert on_error_called - - -def test_realtime__read_succeeds(mocker: MockFixture, faker: Faker): - """ - Tests the `_read` method of the `_RealtimeTranscriberImpl` class. - """ - - expected_transcripts = [ - aai.RealtimeFinalTranscript( - created=faker.date_time(), - text=faker.sentence(), - audio_start=0, - audio_end=1, - confidence=1.0, - words=[], - punctuated=True, - text_formatted=True, - ) - ] - - received_transcripts = [] - - def on_data(data: aai.RealtimeTranscript): - nonlocal received_transcripts - received_transcripts.append(data) - - transcriber = aai.RealtimeTranscriber( - on_data=on_data, - on_error=lambda _: None, - sample_rate=44_100, - ) - - transcriber._impl._websocket = MagicMock() - websocket_recv = [ - json.dumps(msg.dict(), default=str) for msg in expected_transcripts - ] - transcriber._impl._websocket.recv.side_effect = websocket_recv - - with pytest.raises(StopIteration): - transcriber._impl._read() - - assert received_transcripts == expected_transcripts - - -def test_realtime__read_fails(mocker: MockFixture): - """ - Tests the `_read` method of the `_RealtimeTranscriberImpl` class. - """ - - on_error_called = False - - def on_error(error: aai.RealtimeError): - nonlocal on_error_called - on_error_called = True - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, - on_error=on_error, - sample_rate=44_100, - ) - - transcriber._impl._websocket = MagicMock() - error = websockets.exceptions.ConnectionClosedOK(rcvd=None, sent=None) - transcriber._impl._websocket.recv.side_effect = error - - transcriber._impl._read() - - assert on_error_called - - -def test_realtime__write_succeeds(mocker: MockFixture): - """ - Tests the `_write` method of the `_RealtimeTranscriberImpl` class. - """ - audio_chunks = [ - bytes([1, 2, 3, 4, 5]), - bytes([6, 7, 8, 9, 10]), - ] - - actual_sent = [] - - def mocked_send(data: str): - nonlocal actual_sent - actual_sent.append(data) - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, - on_error=lambda _: None, - sample_rate=44_100, - ) - - transcriber._impl._websocket = MagicMock() - transcriber._impl._websocket.send = mocked_send - transcriber._impl._stop_event.is_set = MagicMock(side_effect=[False, False, True]) - - transcriber.stream(audio_chunks[0]) - transcriber.stream(audio_chunks[1]) - - transcriber._impl._write() - - # assert that the correct data was sent (= the exact input bytes) - assert len(actual_sent) == 2 - assert actual_sent[0] == audio_chunks[0] - assert actual_sent[1] == audio_chunks[1] - - -def test_realtime__handle_message_session_begins(mocker: MockFixture): - """ - Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class - with the `SessionBegins` message. - """ - - test_message = { - "message_type": "SessionBegins", - "session_id": str(uuid.uuid4()), - "expires_at": datetime.datetime.now().isoformat(), - } - - on_open_called = False - - def on_open(session_opened: aai.RealtimeSessionOpened): - nonlocal on_open_called - on_open_called = True - assert isinstance(session_opened, aai.RealtimeSessionOpened) - assert session_opened.session_id == uuid.UUID(test_message["session_id"]) - assert session_opened.expires_at.isoformat() == test_message["expires_at"] - - transcriber = aai.RealtimeTranscriber( - on_open=on_open, - on_data=lambda _: None, - on_error=lambda _: None, - sample_rate=44_100, - ) - - transcriber._impl._handle_message(test_message) - - assert on_open_called - - -def test_realtime__handle_message_partial_transcript(mocker: MockFixture): - """ - Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class - with the `PartialTranscript` message. - """ - - test_message = { - "message_type": "PartialTranscript", - "text": "hello world", - "audio_start": 0, - "audio_end": 1500, - "confidence": 0.99, - "created": datetime.datetime.now().isoformat(), - "words": [ - { - "text": "hello", - "start": 0, - "end": 500, - "confidence": 0.99, - }, - { - "text": "world", - "start": 500, - "end": 1500, - "confidence": 0.99, - }, - ], - } - - on_data_called = False - - def on_data(data: aai.RealtimePartialTranscript): - nonlocal on_data_called - on_data_called = True - assert isinstance(data, aai.RealtimePartialTranscript) - assert data.text == test_message["text"] - assert data.audio_start == test_message["audio_start"] - assert data.audio_end == test_message["audio_end"] - assert data.confidence == test_message["confidence"] - assert data.created.isoformat() == test_message["created"] - assert data.words == [ - aai.RealtimeWord( - text=test_message["words"][0]["text"], - start=test_message["words"][0]["start"], - end=test_message["words"][0]["end"], - confidence=test_message["words"][0]["confidence"], - ), - aai.RealtimeWord( - text=test_message["words"][1]["text"], - start=test_message["words"][1]["start"], - end=test_message["words"][1]["end"], - confidence=test_message["words"][1]["confidence"], - ), - ] - - transcriber = aai.RealtimeTranscriber( - on_data=on_data, - on_error=lambda _: None, - sample_rate=44_100, - ) - - transcriber._impl._handle_message(test_message) - - assert on_data_called - - -def test_realtime__handle_message_final_transcript(mocker: MockFixture): - """ - Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class - with the `FinalTranscript` message. - """ - - test_message = { - "message_type": "FinalTranscript", - "text": "Hello, world!", - "audio_start": 0, - "audio_end": 1500, - "confidence": 0.99, - "created": datetime.datetime.now().isoformat(), - "punctuated": True, - "text_formatted": True, - "words": [ - { - "text": "Hello,", - "start": 0, - "end": 500, - "confidence": 0.99, - }, - { - "text": "world!", - "start": 500, - "end": 1500, - "confidence": 0.99, - }, - ], - } - - on_data_called = False - - def on_data(data: aai.RealtimeFinalTranscript): - nonlocal on_data_called - on_data_called = True - assert isinstance(data, aai.RealtimeFinalTranscript) - assert data.text == test_message["text"] - assert data.audio_start == test_message["audio_start"] - assert data.audio_end == test_message["audio_end"] - assert data.confidence == test_message["confidence"] - assert data.created.isoformat() == test_message["created"] - assert data.punctuated == test_message["punctuated"] - assert data.text_formatted == test_message["text_formatted"] - assert data.words == [ - aai.RealtimeWord( - text=test_message["words"][0]["text"], - start=test_message["words"][0]["start"], - end=test_message["words"][0]["end"], - confidence=test_message["words"][0]["confidence"], - ), - aai.RealtimeWord( - text=test_message["words"][1]["text"], - start=test_message["words"][1]["start"], - end=test_message["words"][1]["end"], - confidence=test_message["words"][1]["confidence"], - ), - ] - - transcriber = aai.RealtimeTranscriber( - on_data=on_data, - on_error=lambda _: None, - sample_rate=44_100, - ) - - transcriber._impl._handle_message(test_message) - - assert on_data_called - - -def test_realtime__handle_message_error_message(mocker: MockFixture): - """ - Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class - with the error message. - """ - - test_message = { - "error": "test error", - } - - on_error_called = False - - def on_error(error: aai.RealtimeError): - nonlocal on_error_called - on_error_called = True - assert isinstance(error, aai.RealtimeError) - assert str(error) == test_message["error"] - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, - on_error=on_error, - sample_rate=44_100, - ) - - transcriber._impl._handle_message(test_message) - - assert on_error_called - - -def test_realtime__handle_message_session_information_message(mocker: MockFixture): - """ - Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class - with the session information message. - """ - - test_message = { - "message_type": "SessionInformation", - "audio_duration_seconds": 3000.0, - } - - on_extra_session_information_called = False - - def on_extra_session_information(data: aai.RealtimeSessionInformation): - nonlocal on_extra_session_information_called - on_extra_session_information_called = True - assert isinstance(data, aai.RealtimeSessionInformation) - assert data.audio_duration_seconds == test_message["audio_duration_seconds"] - - transcriber = aai.RealtimeTranscriber( - on_data=lambda _: None, - on_error=lambda _: None, - sample_rate=44_100, - on_extra_session_information=on_extra_session_information, - ) - - transcriber._impl._handle_message(test_message) - - assert on_extra_session_information_called - - -def test_realtime__handle_message_unknown_message(mocker: MockFixture): - """ - Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class - with an unknown message. - """ - - test_message = { - "message_type": "Unknown", - } - - on_data_called = False - - def on_data(data: aai.RealtimeTranscript): - nonlocal on_data_called - on_data_called = True - - on_error_called = False - - def on_error(error: aai.RealtimeError): - nonlocal on_error_called - on_error_called = True - - transcriber = aai.RealtimeTranscriber( - on_data=on_data, - on_error=on_error, - sample_rate=44_100, - ) - - transcriber._impl._handle_message(test_message) - - assert not on_data_called - assert not on_error_called - - -def test_create_temporary_token(httpx_mock: HTTPXMock): - """ - Tests whether the creation of a temporary token is successful. - """ - - # mock the specific endpoint - httpx_mock.add_response( - url=f"{aai.settings.base_url}{ENDPOINT_REALTIME_TOKEN}", - status_code=httpx.codes.OK, - method="POST", - json={"token": "123456"}, - ) - - token = aai.RealtimeTranscriber.create_temporary_token(expires_in=3000) - - assert token == "123456" - - -# TODO: create tests for the `RealtimeTranscriber.close` method diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index eb40850..a637eba 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -681,6 +681,37 @@ def mocked_websocket_connect( assert "interruption_delay=500" in actual_url +def test_client_connect_with_turn_left_pad_ms(mocker: MockFixture): + # Given: client + turn_left_pad_ms=1024 (U3-Pro left-pad window override) + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.u3_rt_pro, + turn_left_pad_ms=1024, + ) + + # When: connect + client.connect(params) + + # Then: parameter reaches the URL + assert "turn_left_pad_ms=1024" in actual_url + + def test_customer_support_audio_capture_warns_when_enabled( mocker: MockFixture, caplog: pytest.LogCaptureFixture ): diff --git a/tests/unit/test_streaming_async.py b/tests/unit/test_streaming_async.py new file mode 100644 index 0000000..bf00701 --- /dev/null +++ b/tests/unit/test_streaming_async.py @@ -0,0 +1,1153 @@ +import asyncio +import json +import logging +from urllib.parse import urlencode + +import pytest +from pytest_mock import MockFixture +from websockets.exceptions import ConnectionClosed, InvalidStatus +from websockets.frames import Close + +from assemblyai.streaming.v3 import ( + AsyncStreamingClient, + SpeechModel, + StreamingClientOptions, + StreamingEvents, + StreamingParameters, +) +from assemblyai.streaming.v3.models import TerminateSession + +pytestmark = pytest.mark.asyncio + + +def _default_params() -> StreamingParameters: + return StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + ) + + +class _FakeAsyncWebSocket: + """Programmable async websocket stand-in for driving AsyncStreamingClient + in tests. Inbound messages are queued via ``push_message`` / + ``push_close``; outbound sends accumulate in ``sent``. + """ + + def __init__(self, send_raises=None): + self._inbound: "asyncio.Queue[object]" = asyncio.Queue() + self._send_raises = send_raises + self.sent: list = [] + self.send_call_count = 0 + self.close_call_count = 0 + self._closed = False + + def push_message(self, data) -> None: + self._inbound.put_nowait(data) + + def push_close(self, exc: BaseException) -> None: + self._inbound.put_nowait(exc) + + async def recv(self): + item = await self._inbound.get() + if isinstance(item, BaseException): + raise item + return item + + async def send(self, data) -> None: + self.send_call_count += 1 + if self._send_raises is not None: + raise self._send_raises + self.sent.append(data) + + async def close(self) -> None: + self.close_call_count += 1 + self._closed = True + + +def _patch_connect(mocker: MockFixture, fake_ws): + """Patch ``websocket_connect_async`` to return the given fake websocket.""" + + async def fake_connect(uri, additional_headers=None, **_kwargs): + fake_connect.uri = uri + fake_connect.additional_headers = additional_headers + return fake_ws + + fake_connect.uri = None + fake_connect.additional_headers = None + mocker.patch( + "assemblyai.streaming.v3.async_client.websocket_connect_async", + new=fake_connect, + ) + return fake_connect + + +async def _wait_for_tasks(client: AsyncStreamingClient, timeout: float = 2.0) -> None: + """Wait until both read/write tasks have exited and stop is set. Raises + ``AssertionError`` on timeout so stalls fail tests deterministically + instead of silently passing.""" + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while loop.time() < deadline: + read_done = client._read_task is None or client._read_task.done() + write_done = client._write_task is None or client._write_task.done() + if read_done and write_done and client._stop_event.is_set(): + return + await asyncio.sleep(0.01) + raise AssertionError( + f"AsyncStreamingClient read/write tasks did not finish within {timeout}s" + ) + + +async def test_client_connect_builds_uri_and_headers(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + fake_connect = _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + params = _default_params() + await client.connect(params) + + expected_qs = urlencode( + { + "sample_rate": params.sample_rate, + "speech_model": str(params.speech_model), + } + ) + assert fake_connect.uri == f"wss://api.example.com/v3/ws?{expected_qs}" + assert fake_connect.additional_headers["Authorization"] == "test" + assert fake_connect.additional_headers["AssemblyAI-Version"] == "2025-05-12" + assert "AssemblyAI/1.0" in fake_connect.additional_headers["User-Agent"] + + await client.disconnect() + + +async def test_client_connect_with_token(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + fake_connect = _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(token="tok-value", api_host="api.example.com") + ) + await client.connect(_default_params()) + + assert fake_connect.additional_headers["Authorization"] == "tok-value" + + await client.disconnect() + + +async def test_stream_bytes_writes_to_socket(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + await client.stream(b"\x00" * 320) + + # Give the write task a moment to drain the queue. + for _ in range(50): + if fake_ws.sent: + break + await asyncio.sleep(0.01) + + assert fake_ws.sent == [b"\x00" * 320] + + await client.disconnect() + + +async def test_stream_sync_iterable(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + chunks = [b"a", b"bb", b"ccc"] + await client.stream(iter(chunks)) + + for _ in range(50): + if len(fake_ws.sent) == 3: + break + await asyncio.sleep(0.01) + + assert fake_ws.sent == chunks + + await client.disconnect() + + +async def test_stream_async_iterable(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + async def gen(): + for chunk in (b"x", b"yy", b"zzz"): + yield chunk + + await client.stream(gen()) + + for _ in range(50): + if len(fake_ws.sent) == 3: + break + await asyncio.sleep(0.01) + + assert fake_ws.sent == [b"x", b"yy", b"zzz"] + + await client.disconnect() + + +async def test_disconnect_terminate_sends_terminate_then_closes(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + await client.disconnect(terminate=True) + + sent_terminate = [ + s for s in fake_ws.sent if isinstance(s, str) and "Terminate" in s + ] + assert len(sent_terminate) == 1 + assert fake_ws.close_call_count >= 1 + + +async def test_begin_event_dispatched_to_handler(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_begin(_client, event): + received.append(event) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Begin, on_begin) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps( + { + "type": "Begin", + "id": "abc", + "expires_at": "2030-01-01T00:00:00", + } + ) + ) + + for _ in range(50): + if received: + break + await asyncio.sleep(0.01) + + assert len(received) == 1 + assert received[0].id == "abc" + + await client.disconnect() + + +async def test_async_handler_is_awaited(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + seen = [] + + async def on_begin(_client, event): + await asyncio.sleep(0) + seen.append(event.id) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Begin, on_begin) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps( + {"type": "Begin", "id": "async-id", "expires_at": "2030-01-01T00:00:00"} + ) + ) + + for _ in range(50): + if seen: + break + await asyncio.sleep(0.01) + + assert seen == ["async-id"] + + await client.disconnect() + + +async def test_sync_and_async_handlers_can_mix(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + sync_seen = [] + async_seen = [] + + def sync_handler(_client, event): + sync_seen.append(event.id) + + async def async_handler(_client, event): + async_seen.append(event.id) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Begin, sync_handler) + client.on(StreamingEvents.Begin, async_handler) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps({"type": "Begin", "id": "mix", "expires_at": "2030-01-01T00:00:00"}) + ) + + for _ in range(50): + if sync_seen and async_seen: + break + await asyncio.sleep(0.01) + + assert sync_seen == ["mix"] + assert async_seen == ["mix"] + + await client.disconnect() + + +async def test_error_event_then_close_fires_only_once( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + caplog.set_level(logging.ERROR) + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_error(_client, err): + received.append(err) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps({"type": "Error", "error": "Invalid API key", "error_code": 4001}) + ) + fake_ws.push_close(ConnectionClosed(rcvd=Close(4001, "Not Authorized"), sent=None)) + + await _wait_for_tasks(client) + + assert len(received) == 1 + assert str(received[0]) == "Invalid API key" + assert received[0].code == 4001 + + error_logs = [ + rec + for rec in caplog.records + if "Streaming error" in rec.message and "4001" in rec.message + ] + close_logs = [ + rec + for rec in caplog.records + if "Connection closed" in rec.message and "4001" in rec.message + ] + assert len(error_logs) == 1 + # ``_report_server_error`` closes the websocket locally and sets stop, so + # the read loop exits before the pushed trailing close is recv'd. No close + # log is emitted in this path — the Error event already captured the cause. + assert close_logs == [] + + await client.disconnect() + + +async def test_server_error_without_trailing_close_tears_down(mocker: MockFixture): + """Regression: a server ``Error`` frame with no trailing close must still + drive the read loop to exit. Without local teardown in + ``_report_server_error``, ``await ws.recv()`` would block indefinitely.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_error(_client, err): + received.append(err) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + await client.connect(_default_params()) + + # Push an Error frame and nothing else — no trailing close. + fake_ws.push_message( + json.dumps({"type": "Error", "error": "boom", "error_code": 4002}) + ) + + # If teardown is missing this raises AssertionError after timeout. + await _wait_for_tasks(client) + + assert len(received) == 1 + assert received[0].code == 4002 + assert fake_ws.close_call_count >= 1 + + await client.disconnect() + + +async def test_clean_close_emits_no_error_or_log( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + caplog.set_level(logging.ERROR) + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_error(_client, err): + received.append(err) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + await client.connect(_default_params()) + + fake_ws.push_close(ConnectionClosed(rcvd=Close(1000, "session ended"), sent=None)) + + await _wait_for_tasks(client) + + assert received == [] + error_logs = [rec for rec in caplog.records if rec.levelno >= logging.ERROR] + assert error_logs == [] + + await client.disconnect() + + +async def test_turn_handler_exception_does_not_kill_read_task(mocker: MockFixture): + """A raising Turn handler must not propagate out of the read task; the + next inbound message should still be delivered.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + seen = [] + + def bad_handler(_client, _turn): + raise RuntimeError("boom") + + def good_handler(_client, turn): + seen.append(turn.end_of_turn) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Turn, bad_handler) + client.on(StreamingEvents.Turn, good_handler) + await client.connect(_default_params()) + + turn_payload = { + "type": "Turn", + "turn_order": 1, + "turn_is_formatted": False, + "end_of_turn": False, + "transcript": "hello", + "end_of_turn_confidence": 0.5, + "words": [], + } + fake_ws.push_message(json.dumps(turn_payload)) + fake_ws.push_message(json.dumps({**turn_payload, "turn_order": 2})) + + for _ in range(100): + if len(seen) == 2: + break + await asyncio.sleep(0.01) + + assert seen == [False, False] + + await client.disconnect() + + +async def test_warning_handler_exception_does_not_kill_read_task(mocker: MockFixture): + """A raising Warning handler must not propagate out of the read task.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def bad_handler(_client, _warning): + raise RuntimeError("boom") + + def good_handler(_client, warning): + received.append(warning.warning_code) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Warning, bad_handler) + client.on(StreamingEvents.Warning, good_handler) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps({"type": "Warning", "warning": "first", "warning_code": 1}) + ) + fake_ws.push_message( + json.dumps({"type": "Warning", "warning": "second", "warning_code": 2}) + ) + + for _ in range(100): + if len(received) == 2: + break + await asyncio.sleep(0.01) + + assert received == [1, 2] + + await client.disconnect() + + +async def test_stream_before_connect_raises_runtime_error(): + """``stream()`` called before ``connect()`` must raise RuntimeError rather + than silently dropping data. Silent drop would diverge from the sync client + (which buffers pre-connect data) in a way that's easy to miss — explicit + failure surfaces the misuse.""" + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + async def gen(): + yield b"x" + + for data in (b"\x00" * 10, iter([b"a", b"b"]), gen()): + with pytest.raises(RuntimeError, match="not connected"): + await client.stream(data) + + +async def test_set_params_before_connect_raises_runtime_error(): + from assemblyai.streaming.v3 import ( + StreamingSessionParameters, + ) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + with pytest.raises(RuntimeError, match="not connected"): + await client.set_params(StreamingSessionParameters(min_turn_silence=200)) + + +async def test_force_endpoint_before_connect_raises_runtime_error(): + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + with pytest.raises(RuntimeError, match="not connected"): + await client.force_endpoint() + + +async def test_stream_after_close_is_noop(mocker: MockFixture): + """Post-close ``stream()`` must stay a silent no-op so user cleanup paths + (e.g. a finally block draining a queue) don't have to wrap each call in + try/except. Pre-connect raise + post-close no-op gives both: misuse is + loud, cleanup is quiet.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + # Simulate a clean close — read task exits, _stop_event is set. + fake_ws.push_close(ConnectionClosed(rcvd=Close(1000, "bye"), sent=None)) + await _wait_for_tasks(client) + + # No raise: post-close stream is safe for cleanup. + await client.stream(b"\x00" * 10) + await client.disconnect() + + +async def test_handler_exception_does_not_block_shutdown(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + def bad_handler(_client, _err): + raise RuntimeError("boom") + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, bad_handler) + await client.connect(_default_params()) + + fake_ws.push_close(ConnectionClosed(rcvd=Close(1011, "server error"), sent=None)) + + await _wait_for_tasks(client) + # If the handler exception had escaped, _wait_for_tasks would time out. + assert client._read_task.done() + + await client.disconnect() + + +async def test_invalid_status_during_connect_dispatches_error(mocker: MockFixture): + received = [] + + def on_error(_client, err): + received.append(err) + + response = type("R", (), {"status_code": 401})() + err = InvalidStatus(response=response) + + async def failing_connect(*_args, **_kwargs): + raise err + + mocker.patch( + "assemblyai.streaming.v3.async_client.websocket_connect_async", + new=failing_connect, + ) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + + await client.connect(_default_params()) + + assert len(received) == 1 + assert received[0].code == 401 + assert "HTTP 401" in str(received[0]) + + +async def test_terminate_session_bypasses_stop_gate(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + # Pre-set stop, then queue a TerminateSession directly. The write loop must + # still send it before exiting. + client._stop_event.set() + await client._write_queue.put(TerminateSession()) + + for _ in range(100): + if fake_ws.send_call_count >= 1: + break + await asyncio.sleep(0.01) + + assert fake_ws.send_call_count >= 1 + assert any(isinstance(s, str) and "Terminate" in s for s in fake_ws.sent) + + await client.disconnect() + + +async def test_create_temporary_token(mocker: MockFixture): + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + captured = {} + + async def fake_get(self, url, params=None): + captured["url"] = url + captured["params"] = params + + class R: + def raise_for_status(self_inner): + pass + + def json(self_inner): + return {"token": "tmp-tok"} + + return R() + + mocker.patch("httpx.AsyncClient.get", new=fake_get) + + token = await client.create_temporary_token( + expires_in_seconds=60, max_session_duration_seconds=600 + ) + assert token == "tmp-tok" + assert captured["url"] == "/v3/token" + assert captured["params"] == { + "expires_in_seconds": 60, + "max_session_duration_seconds": 600, + } + + await client._client.aclose() + + +async def test_create_temporary_token_forwards_zero_expires(mocker: MockFixture): + """Regression: ``expires_in_seconds=0`` must reach the server (so it can + reject it with a clear error) rather than being silently dropped by a + falsy check.""" + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + captured = {} + + async def fake_get(self, url, params=None): + captured["params"] = params + + class R: + def raise_for_status(self_inner): + pass + + def json(self_inner): + return {"token": "tmp-tok"} + + return R() + + mocker.patch("httpx.AsyncClient.get", new=fake_get) + + await client.create_temporary_token(expires_in_seconds=0) + + assert captured["params"] == {"expires_in_seconds": 0} + + await client._client.aclose() + + +async def test_connect_twice_raises(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + with pytest.raises(RuntimeError, match="already been connected"): + await client.connect(_default_params()) + + await client.disconnect() + + +async def test_connect_after_handshake_failure_raises(mocker: MockFixture): + """Regression: a failed connect leaves ``_connection_closed_reported`` set + and ``_stop_event`` set. A second ``connect()`` attempt on the same client + must surface a clear error, not silently produce a dead read/write loop.""" + response = type("R", (), {"status_code": 401})() + err = InvalidStatus(response=response) + + async def failing_connect(*_args, **_kwargs): + raise err + + mocker.patch( + "assemblyai.streaming.v3.async_client.websocket_connect_async", + new=failing_connect, + ) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + await client.connect(_default_params()) + + with pytest.raises(RuntimeError, match="already been connected"): + await client.connect(_default_params()) + + +async def test_set_params_enqueues_update_configuration(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + from assemblyai.streaming.v3.models import ( + StreamingSessionParameters, + ) + + await client.set_params( + StreamingSessionParameters(end_of_turn_confidence_threshold=0.5) + ) + + for _ in range(100): + update_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "UpdateConfiguration" in s + ] + if update_frames: + break + await asyncio.sleep(0.01) + + update_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "UpdateConfiguration" in s + ] + assert len(update_frames) == 1 + payload = json.loads(update_frames[0]) + assert payload["type"] == "UpdateConfiguration" + assert payload["end_of_turn_confidence_threshold"] == 0.5 + + await client.disconnect() + + +async def test_force_endpoint_enqueues_force_endpoint_frame(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + await client.force_endpoint() + + for _ in range(100): + force_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "ForceEndpoint" in s + ] + if force_frames: + break + await asyncio.sleep(0.01) + + force_frames = [ + s for s in fake_ws.sent if isinstance(s, str) and "ForceEndpoint" in s + ] + assert len(force_frames) == 1 + payload = json.loads(force_frames[0]) + assert payload["type"] == "ForceEndpoint" + + await client.disconnect() + + +async def test_warning_event_dispatched_to_handler(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_warning(_client, event): + received.append(event) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Warning, on_warning) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps({"type": "Warning", "warning": "slow audio", "warning_code": 1234}) + ) + + for _ in range(100): + if received: + break + await asyncio.sleep(0.01) + + assert len(received) == 1 + assert received[0].warning == "slow audio" + assert received[0].warning_code == 1234 + + await client.disconnect() + + +async def test_termination_event_sets_stop_and_dispatches(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_termination(_client, event): + received.append(event) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Termination, on_termination) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps( + { + "type": "Termination", + "audio_duration_seconds": 12, + "session_duration_seconds": 15, + } + ) + ) + + # Termination sets stop_event but doesn't close the socket; wait for the + # handler to fire and stop_event to flip. + for _ in range(100): + if received and client._stop_event is not None and client._stop_event.is_set(): + break + await asyncio.sleep(0.01) + + assert len(received) == 1 + assert client._stop_event is not None + assert client._stop_event.is_set() + + await client.disconnect() + + +async def test_disconnect_before_connect_is_safe_noop(mocker: MockFixture): + """``disconnect()`` is safe before ``connect()``. With the httpx client + lazy-constructed (no work done in ``__init__``), there is nothing to close + on a never-used client, so ``aclose`` should not be invoked.""" + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + closed = [] + + async def fake_aclose(self): + closed.append(True) + + mocker.patch("httpx.AsyncClient.aclose", new=fake_aclose) + + await client.disconnect() + + # Nothing was ever instantiated, so nothing to close. + assert closed == [] + assert client._read_task is None + assert client._write_task is None + + +async def test_construct_only_does_not_instantiate_httpx_client( + mocker: MockFixture, +): + """Constructing an ``AsyncStreamingClient`` and never calling + ``connect()`` / ``create_temporary_token()`` / ``disconnect()`` must not + instantiate an ``httpx.AsyncClient`` — otherwise an unused client leaks + the pool. The HTTP client should be built lazily on first use.""" + import httpx + + constructed = [] + real_init = httpx.AsyncClient.__init__ + + def counting_init(self, *args, **kwargs): + constructed.append(True) + return real_init(self, *args, **kwargs) + + mocker.patch.object(httpx.AsyncClient, "__init__", counting_init) + + AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + assert constructed == [], ( + "AsyncStreamingClient should not eagerly instantiate httpx.AsyncClient; " + "got constructions: " + str(constructed) + ) + + +async def test_async_context_manager_calls_disconnect_on_exit(mocker: MockFixture): + """``async with AsyncStreamingClient(opts) as c:`` must invoke + ``disconnect()`` on block exit so callers can't forget cleanup.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + async with AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) as client: + await client.connect(_default_params()) + await client.stream(b"\x00" * 32) + + # On exit, disconnect should have torn down read/write tasks. + assert client._read_task is not None and client._read_task.done() + assert client._write_task is not None and client._write_task.done() + assert client._stop_event is not None and client._stop_event.is_set() + + +async def test_async_context_manager_disconnect_runs_on_exception( + mocker: MockFixture, +): + """Exception inside the ``async with`` body must still trigger + ``disconnect()`` so the websocket / http client don't leak when user + code raises.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + class _Boom(Exception): + pass + + client_ref = {} + + with pytest.raises(_Boom): + async with AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) as client: + client_ref["c"] = client + await client.connect(_default_params()) + raise _Boom() + + client = client_ref["c"] + assert client._stop_event is not None and client._stop_event.is_set() + assert client._websocket is None or fake_ws.close_call_count >= 1 + + +async def test_disconnect_closes_http_client_when_used(mocker: MockFixture): + """Once the lazy ``httpx.AsyncClient`` has been instantiated (by a call + that goes through HTTP — e.g. ``create_temporary_token``), ``disconnect`` + must close it so the pool doesn't leak.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + async def fake_get(self, url, params=None): + class _R: + def raise_for_status(self): + pass + + def json(self): + return {"token": "t"} + + return _R() + + mocker.patch("httpx.AsyncClient.get", new=fake_get) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + # Force the http client to be instantiated. + await client.create_temporary_token(expires_in_seconds=60) + + closed = [] + + async def fake_aclose(self): + closed.append(True) + + mocker.patch("httpx.AsyncClient.aclose", new=fake_aclose) + + await client.disconnect() + + assert closed == [True] + + +async def test_server_error_dedups_concurrent_write_side_close(mocker: MockFixture): + """Regression: a slow async ``on_error`` handler must not race a concurrent + write-side ``ConnectionClosed`` into a duplicate dispatch. The + ``_server_error_reported`` flag is set synchronously before the first + ``await`` in ``_report_server_error`` — this test locks in that ordering.""" + close_exc = ConnectionClosed(rcvd=Close(1011, "send-side close"), sent=None) + fake_ws = _FakeAsyncWebSocket(send_raises=close_exc) + _patch_connect(mocker, fake_ws) + + received = [] + handler_started = asyncio.Event() + handler_release = asyncio.Event() + + async def slow_on_error(_client, err): + received.append(err) + handler_started.set() + await handler_release.wait() + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, slow_on_error) + await client.connect(_default_params()) + + # Push a server Error frame; the read task enters the slow handler. + fake_ws.push_message( + json.dumps({"type": "Error", "error": "boom", "error_code": 4002}) + ) + await asyncio.wait_for(handler_started.wait(), timeout=1.0) + + # While the handler is parked, trigger a write-side close concurrently. + await client.stream(b"\x00" * 32) + for _ in range(50): + if fake_ws.send_call_count >= 1: + break + await asyncio.sleep(0.01) + + # Release the handler; the read task finishes dispatch and exits. + handler_release.set() + + await _wait_for_tasks(client) + + assert len(received) == 1, ( + f"expected exactly one on_error despite concurrent write-side close, " + f"got {received}" + ) + assert received[0].code == 4002 + + await client.disconnect() + + +async def test_disconnect_during_slow_handler_tears_down(mocker: MockFixture): + """Regression: ``disconnect()`` while an async handler is parked in a long + ``await`` must cleanly cancel the read task. ``CancelledError`` is a + ``BaseException`` (not ``Exception``), so it propagates through + ``_invoke_handler`` and out of the read task — ``disconnect()`` then + completes the cleanup.""" + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + handler_started = asyncio.Event() + + async def slow_handler(_client, _event): + handler_started.set() + await asyncio.sleep(60) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Begin, slow_handler) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps({"type": "Begin", "id": "abc", "expires_at": "2030-01-01T00:00:00"}) + ) + await asyncio.wait_for(handler_started.wait(), timeout=1.0) + + # disconnect() should not hang waiting for the parked sleep — the read + # task is cancelled, CancelledError propagates, and disconnect returns. + await asyncio.wait_for(client.disconnect(), timeout=2.0) + + assert client._read_task.done() + + +async def test_write_side_close_is_dispatched_when_read_short_circuits_on_stop( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + """Regression: if the read task observes ``_stop_event`` at the top of its + loop (e.g. after processing a buffered message) before its next ``recv()`` + raises, the write task must still dispatch the connection-closed event. + Previously the write task only set stop and exited, so this close went + unreported.""" + caplog.set_level(logging.ERROR) + + close_exc = ConnectionClosed(rcvd=Close(1011, "send-side close"), sent=None) + fake_ws = _FakeAsyncWebSocket(send_raises=close_exc) + _patch_connect(mocker, fake_ws) + + received = [] + + def on_error(_client, err): + received.append(err) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + await client.connect(_default_params()) + + # Queue a write so the write task hits send() and raises ConnectionClosed. + await client.stream(b"\x00" * 32) + + # Wait for write task to finish dispatching the close. + for _ in range(200): + if received: + break + await asyncio.sleep(0.01) + + assert len(received) == 1, ( + f"expected exactly one on_error from write-side close, got {received}" + ) + assert received[0].code == 1011 + + await client.disconnect() diff --git a/tox.ini b/tox.ini index 3bfddd2..23daedf 100644 --- a/tox.ini +++ b/tox.ini @@ -27,7 +27,14 @@ deps = pytest-xdist pytest-mock pytest-cov + pytest-asyncio factory-boy allowlist_externals = pytest commands = pytest -n auto --cov-report term --cov-report xml:coverage.xml --cov=assemblyai + +[pytest] +# Streaming async tests use explicit ``pytestmark = pytest.mark.asyncio``. +# ``strict`` keeps that opt-in pattern and silences the pytest-asyncio +# unset-mode deprecation warning on >=0.21. +asyncio_mode = strict