Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assemblyai/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.63.1"
__version__ = "0.64.0"
4 changes: 4 additions & 0 deletions assemblyai/streaming/v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
StreamingError,
StreamingEvents,
StreamingParameters,
StreamingPiiPolicy,
StreamingPiiSubstitution,
StreamingSessionParameters,
TerminationEvent,
TurnEvent,
Expand All @@ -31,6 +33,8 @@
"StreamingError",
"StreamingEvents",
"StreamingParameters",
"StreamingPiiPolicy",
"StreamingPiiSubstitution",
"StreamingSessionParameters",
"TerminationEvent",
"TurnEvent",
Expand Down
241 changes: 194 additions & 47 deletions assemblyai/streaming/v3/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def _dump_model(model: BaseModel):
return model.dict(exclude_none=True)


def _parse_model(model_class, data):
if hasattr(model_class, "model_validate"):
return model_class.model_validate(data)
return model_class.parse_obj(data)


def _normalize_min_turn_silence(params_dict: dict) -> dict:
"""Collapse `min_end_of_turn_silence_when_confident` into `min_turn_silence` so only
one wire key is ever sent. Emits deprecation warnings."""
Expand All @@ -65,6 +71,31 @@ def _normalize_min_turn_silence(params_dict: dict) -> dict:
return params_dict


def _normalize_voice_focus(params_dict: dict) -> dict:
"""Collapse `noise_suppression_model` / `noise_suppression_threshold` into
`voice_focus` / `voice_focus_threshold` so only the new wire keys are sent.
Emits deprecation warnings."""
for old_key, new_key in (
("noise_suppression_model", "voice_focus"),
("noise_suppression_threshold", "voice_focus_threshold"),
):
old = params_dict.pop(old_key, None)
if old is None:
continue
if new_key in params_dict:
logger.warning(
f"[Deprecation Warning] Both `{old_key}` and `{new_key}` are set. "
f"Using `{new_key}`; `{old_key}` is deprecated."
)
else:
logger.warning(
f"[Deprecation Warning] `{old_key}` is deprecated and will be removed "
f"in a future release. Please use `{new_key}` instead."
)
params_dict[new_key] = old
return params_dict


def _dump_model_json(model: BaseModel):
if hasattr(model, "model_dump_json"):
return model.model_dump_json(exclude_none=True)
Expand Down Expand Up @@ -94,6 +125,17 @@ def __init__(self, options: StreamingClientOptions):
self._write_thread = threading.Thread(target=self._write_message)
self._read_thread = threading.Thread(target=self._read_message)
self._stop_event = threading.Event()
# Both flags are read and set only on the read thread (or on the main
# thread before workers start, for handshake errors). Plain bools are
# sufficient — no cross-thread synchronization is needed.
self._connection_closed_reported = False
self._server_error_reported = False
# Deliberate single-slot shared-memory handoff: the write thread parks
# a ConnectionClosed here and the read thread drains it. Synchronization
# is provided by `_stop_event.set()` (write side) + `recv(timeout=1)`
# (read side), which together give a happens-before within ~1s.
self._pending_close_error: Optional[Exception] = None
self._websocket = None

def connect(self, params: StreamingParameters) -> None:
if params.speech_model == "u3-pro":
Expand All @@ -102,7 +144,15 @@ def connect(self, params: StreamingParameters) -> None:
"Please use `u3-rt-pro` instead."
)

params_dict = _normalize_min_turn_silence(_dump_model(params))
if params.customer_support_audio_capture:
logger.warning(
"`customer_support_audio_capture=True` will record session audio. "
"Only enable this when explicitly coordinating with AssemblyAI support."
)

params_dict = _normalize_voice_focus(
_normalize_min_turn_silence(_dump_model(params))
)

# JSON-encode list and dict parameters for proper API compatibility (e.g., keyterms_prompt, llm_gateway)
for key, value in params_dict.items():
Expand Down Expand Up @@ -132,8 +182,22 @@ def connect(self, params: StreamingParameters) -> None:
additional_headers=headers,
open_timeout=15,
)
except websockets.exceptions.ConnectionClosed as exc:
self._handle_error(exc)
except websockets.exceptions.InvalidStatus as exc:
status_code = getattr(getattr(exc, "response", None), "status_code", None)
self._report_connection_closed(
StreamingError(
message=f"WebSocket handshake rejected (HTTP {status_code})",
code=status_code,
)
)
return
except (
websockets.exceptions.InvalidHandshake,
websockets.exceptions.ConnectionClosed,
OSError,
TimeoutError,
) as exc:
self._report_connection_closed(exc)
return

self._write_thread.start()
Expand All @@ -145,23 +209,40 @@ def disconnect(self, terminate: bool = False) -> None:
if terminate and not self._stop_event.is_set():
self._write_queue.put(TerminateSession())

try:
self._read_thread.join()
self._write_thread.join()
self._stop_event.set()

current = threading.current_thread()
for thread in (self._read_thread, self._write_thread):
if thread is current or not thread.is_alive():
continue
try:
thread.join()
except RuntimeError as exc:
logger.debug("Thread join skipped: %s", exc)

if self._websocket:
self._websocket.close()
except Exception:
pass
self._close_websocket()

def _close_websocket(self) -> None:
if not self._websocket:
return
try:
self._websocket.close()
except (OSError, websockets.exceptions.WebSocketException) as exc:
logger.debug("Error closing websocket: %s", exc)

def stream(
    self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
) -> None:
    """Queue audio for the write thread to send.

    Args:
        data: Either a single bytes-like payload or an iterable/generator
            yielding payloads. `bytearray` and `memoryview` values are
            converted to `bytes` before being queued.

    Nothing is queued once the stop event is set; an iterable that is being
    consumed is abandoned at the next chunk boundary.
    """
    if self._stop_event.is_set():
        return

    # Treat every bytes-like object as ONE payload. Without this, a
    # bytearray/memoryview would fall through to the loop below and be
    # iterated byte by byte, queueing ints that the write loop rejects.
    if isinstance(data, (bytes, bytearray, memoryview)):
        self._write_queue.put(bytes(data))
        return

    for chunk in data:
        if self._stop_event.is_set():
            return
        if isinstance(chunk, (bytearray, memoryview)):
            # The write loop only sends `bytes`; normalise here.
            chunk = bytes(chunk)
        self._write_queue.put(chunk)

def set_params(self, params: StreamingSessionParameters):
Expand All @@ -178,15 +259,24 @@ def on(self, event: StreamingEvents, handler: Callable) -> None:
self._handlers[event].append(handler)

def _write_message(self) -> None:
while not self._stop_event.is_set():
while True:
if not self._websocket:
raise ValueError("Not connected to the WebSocket server")

try:
data = self._write_queue.get(timeout=1)
except queue.Empty:
if self._stop_event.is_set():
return
continue

# TerminateSession bypasses the stop gate so disconnect(terminate=True)
# can always send it, even when stop is set between put() and the
# write loop's next iteration.
is_terminate = isinstance(data, TerminateSession)
if not is_terminate and self._stop_event.is_set():
return

try:
if isinstance(data, bytes):
self._websocket.send(data)
Expand All @@ -195,20 +285,36 @@ def _write_message(self) -> None:
else:
raise ValueError(f"Attempted to send invalid message: {type(data)}")
except websockets.exceptions.ConnectionClosed as exc:
self._handle_error(exc)
# Defer reporting to the read thread so all on_error dispatch
# happens on a single thread (no cross-thread dedup race).
self._pending_close_error = exc
self._stop_event.set()
return

if is_terminate:
return

def _read_message(self) -> None:
while not self._stop_event.is_set():
while True:
if not self._websocket:
raise ValueError("Not connected to the WebSocket server")

# Drain a write-thread close before honoring stop, so a stop set by
# the write thread doesn't cause us to exit silently with an
# unreported close.
if self._pending_close_error is not None:
pending, self._pending_close_error = self._pending_close_error, None
self._report_connection_closed(pending)
return
if self._stop_event.is_set():
return

try:
message_data = self._websocket.recv(timeout=1)
except TimeoutError:
continue
except websockets.exceptions.ConnectionClosed as exc:
self._handle_error(exc)
self._report_connection_closed(exc)
return

try:
Expand All @@ -220,7 +326,7 @@ def _read_message(self) -> None:
message = self._parse_message(message_json)

if isinstance(message, ErrorEvent):
self._handle_error(message)
self._report_server_error(message)
elif isinstance(message, WarningEvent):
self._handle_warning(message)
elif message:
Expand All @@ -244,23 +350,23 @@ def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]:
event_type = self._parse_event_type(message_type)

if event_type == StreamingEvents.Begin:
return BeginEvent.model_validate(data)
return _parse_model(BeginEvent, data)
elif event_type == StreamingEvents.Termination:
return TerminationEvent.model_validate(data)
return _parse_model(TerminationEvent, data)
elif event_type == StreamingEvents.Turn:
return TurnEvent.model_validate(data)
return _parse_model(TurnEvent, data)
elif event_type == StreamingEvents.SpeechStarted:
return SpeechStartedEvent.model_validate(data)
return _parse_model(SpeechStartedEvent, data)
elif event_type == StreamingEvents.LLMGatewayResponse:
return LLMGatewayResponseEvent.model_validate(data)
return _parse_model(LLMGatewayResponseEvent, data)
elif event_type == StreamingEvents.Error:
return ErrorEvent.model_validate(data)
return _parse_model(ErrorEvent, data)
elif event_type == StreamingEvents.Warning:
return WarningEvent.model_validate(data)
return _parse_model(WarningEvent, data)
else:
return None
elif "error" in data:
return ErrorEvent.model_validate(data)
return _parse_model(ErrorEvent, data)

return None

Expand All @@ -281,44 +387,85 @@ def _handle_warning(self, warning: WarningEvent):
for handler in self._handlers[StreamingEvents.Warning]:
handler(self, warning)

def _handle_error(
def _report_server_error(self, error: ErrorEvent) -> None:
    """Report an error event received from the server to the error handlers.

    Sets `_server_error_reported`, logs the error, and dispatches a
    `StreamingError` built from the event's `error` and `error_code`.
    """
    self._server_error_reported = True

    message, code = error.error, error.error_code
    streaming_error = StreamingError(message=message, code=code)

    logger.error("Streaming error: %s (code=%s)", message, code)
    self._dispatch_error(streaming_error)

def _report_connection_closed(
self,
error: Union[
StreamingError,
ErrorEvent,
websockets.exceptions.ConnectionClosed,
OSError,
],
):
parsed_error = self._parse_error(error)
) -> None:
# Idempotent: defensive guard in case future callers (e.g. another
# connect-time error path) reach this method twice.
if self._connection_closed_reported:
return
self._connection_closed_reported = True
self._stop_event.set()

for handler in self._handlers[StreamingEvents.Error]:
handler(self, parsed_error)
streaming_error = self._build_connection_closed_error(error)

self.disconnect()
# Clean close (code 1000) → no streaming_error, nothing to report.
if streaming_error is None:
self._close_websocket()
return

def _parse_error(
self,
if isinstance(error, websockets.exceptions.ConnectionClosed):
reason = error.reason or "no reason given"
logger.error("Connection closed: %s (code=%s)", reason, error.code)
else:
logger.error(
"Connection failed: %s (code=%s)",
streaming_error,
streaming_error.code,
)

# If a server Error frame already fired on_error, the close is the
# effect, not a new cause — log it (above) but skip the duplicate
# user-visible error.
if not self._server_error_reported:
self._dispatch_error(streaming_error)

self._close_websocket()

def _dispatch_error(self, error: StreamingError) -> None:
    """Call every registered `StreamingEvents.Error` handler with `error`.

    An exception raised by one handler is logged with its traceback and does
    not stop the remaining handlers from running.
    """
    error_handlers = self._handlers[StreamingEvents.Error]
    for on_error in error_handlers:
        try:
            on_error(self, error)
        except Exception:
            logger.exception("on_error handler raised")

@staticmethod
def _build_connection_closed_error(
error: Union[
StreamingError,
ErrorEvent,
websockets.exceptions.ConnectionClosed,
OSError,
],
) -> StreamingError:
) -> Optional[StreamingError]:
if isinstance(error, StreamingError):
return error
if isinstance(error, ErrorEvent):
return StreamingError(
message=error.error,
code=error.error_code,
)
elif isinstance(error, websockets.exceptions.ConnectionClosed):
if error.code in StreamingErrorCodes:
error_message = StreamingErrorCodes[error.code]
return StreamingError(message=error.error, code=error.error_code)
if isinstance(error, websockets.exceptions.ConnectionClosed):
if error.code == 1000:
return None
if error.code is not None and error.code in StreamingErrorCodes:
message = StreamingErrorCodes[error.code]
else:
error_message = error.reason

if error.code != 1000:
return StreamingError(message=error_message, code=error.code)

return StreamingError(
message=f"Unknown error: {error}",
)
message = error.reason or f"Connection closed (code={error.code})"
return StreamingError(message=message, code=error.code)
return StreamingError(message=f"Connection failed: {error}")

def create_temporary_token(
self,
Expand Down
Loading
Loading