101 changes: 78 additions & 23 deletions src/providers/minimax_provider.py
@@ -168,14 +168,21 @@ def chat_stream_response(
abort_signal: Any = None,
**kwargs
) -> ChatResponse:
"""Stream Minimax response with abort-signal-aware cancellation.

Minimax wraps the anthropic SDK against its compatible endpoint,
so the response-close listener pattern AnthropicProvider uses
works here too. Same contract: pre-call fast-path, register-
then-recheck listener that closes the underlying HTTP response,
signal-state-authoritative abort detection in the exception
handler, post-with-block recheck, ``finally`` detaches the
listener.
"""
from src.utils.abort_controller import AbortError

# Pre-call fast-path: matches AnthropicProvider. A signal that
# tripped at a turn boundary skips the API round-trip entirely.
# Mid-stream cancellation is handled below by the same
# response-close listener pattern AnthropicProvider uses; the
# Minimax-compatible client is the same underlying ``anthropic``
# package, so the listener works unchanged.
if abort_signal is not None and getattr(abort_signal, "aborted", False):
raise AbortError(getattr(abort_signal, "reason", None) or "user_interrupt")
model = self._get_model(**kwargs)
max_tokens = kwargs.get("max_tokens", 4096)
@@ -188,24 +195,72 @@ def chat_stream_response(
extra_kwargs["tools"] = tools

streamed_text = ""
final_message: Any = None
abort_listener: Any = None
try:
with client.messages.stream(
model=model,
max_tokens=max_tokens,
messages=minimax_messages,
**({"system": system} if system else {}),
**extra_kwargs,
**{k: v for k, v in kwargs.items() if k not in ["model", "max_tokens", "tools"]},
) as stream:
if abort_signal is not None:
def _close_stream_on_abort() -> None:
try:
response = getattr(stream, "response", None)
if response is not None:
close = getattr(response, "close", None)
if callable(close):
close()
except Exception:
pass

# Register-then-recheck: see AnthropicProvider for the
# full race analysis. ``_fire`` snapshots the listener
# list before iterating, so a listener appended after
# the snapshot is silently dropped; the post-add
# ``aborted`` read closes the gap.
abort_listener = abort_signal.add_listener(
_close_stream_on_abort, once=True,
)
if abort_signal.aborted:
_close_stream_on_abort()

for text in stream.text_stream:
if not text:
continue
streamed_text += text
if on_text_chunk is not None:
on_text_chunk(text)
try:
final_message = stream.get_final_message()
except Exception:
final_message = None
except Exception as streaming_exc:
# Abort path: signal state is authoritative — different SDK
# versions raise different exception types when the response
# is closed mid-read.
if abort_signal is not None and getattr(abort_signal, "aborted", False):
raise AbortError(
getattr(abort_signal, "reason", None) or "user_interrupt"
) from streaming_exc
raise
finally:
if abort_listener is not None and abort_signal is not None:
try:
abort_signal.remove_listener(abort_listener)
except Exception:
pass

# Stream completed normally but abort may have fired between
# ``__exit__`` and here. Surface it at the same boundary every
# other path uses.
if abort_signal is not None and getattr(abort_signal, "aborted", False):
raise AbortError(
getattr(abort_signal, "reason", None) or "user_interrupt"
)

if final_message is not None:
return self._build_chat_response(final_message)
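
Both hunks lean on the abort-controller contract in ``src/utils/abort_controller``, which this PR imports but does not show. The sketch below is a minimal reading of the interface the call sites imply (``aborted``, ``reason``, ``add_listener(fn, once=True)`` returning a removable handle, ``remove_listener(handle)``), plus the snapshot-before-iterate ``_fire`` that the register-then-recheck comments reference. Names and internals are inferred, not the actual module:

import threading
from typing import Callable, Optional


class AbortError(Exception):
    """Raised at provider boundaries when a stream is cancelled."""


class AbortSignal:
    """Sticky abort flag plus best-effort listeners (illustrative sketch)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._aborted = False
        self.reason: Optional[str] = None
        self._listeners: list[Callable[[], None]] = []

    @property
    def aborted(self) -> bool:
        return self._aborted

    def add_listener(self, fn: Callable[[], None], once: bool = False) -> Callable[[], None]:
        # The callable doubles as the removable handle; ``once`` listeners
        # are dropped after ``_fire`` runs them (elided here for brevity).
        with self._lock:
            self._listeners.append(fn)
        return fn

    def remove_listener(self, handle: Callable[[], None]) -> None:
        with self._lock:
            if handle in self._listeners:
                self._listeners.remove(handle)

    def _fire(self, reason: Optional[str] = None) -> None:
        with self._lock:
            if self._aborted:
                return
            self._aborted = True   # state flips before any listener runs,
            self.reason = reason   # so a post-add recheck always observes it
            listeners = list(self._listeners)  # snapshot: later adds are dropped
        for fn in listeners:
            try:
                fn()               # e.g. _close_stream_on_abort
            except Exception:
                pass               # listeners are best-effort by contract

The snapshot in ``_fire`` is exactly why both providers re-read ``aborted`` right after ``add_listener``: a listener appended after the snapshot is never invoked, but the sticky ``aborted`` flag is already observable, so the post-add recheck closes the race.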
185 changes: 138 additions & 47 deletions src/providers/openai_compatible.py
@@ -328,14 +328,34 @@ def chat_stream_response(
abort_signal: Any = None,
**kwargs
) -> ChatResponse:
"""Stream OpenAI-compatible chunks while rebuilding the final response."""
"""Stream OpenAI-compatible chunks while rebuilding the final response.

ESC-cancellation: when ``abort_signal`` is provided, two defenses
cooperate so ESC unwinds the stream promptly regardless of the
provider's chunk cadence:

* **Response-close listener** registered on the abort signal —
calls ``stream.response.close()``. Closes the underlying HTTP
socket so the SDK's blocking next-chunk read raises
immediately, even when the model is in a long gap between
chunks (extended thinking, tool_use generation).
* **In-loop abort check** at the top of each ``for chunk in
stream`` iteration — catches the case where chunks arrive
back-to-back and the listener's close lands one iteration
late, so we stop iterating before the next read.

Mirrors the contract ``AnthropicProvider.chat_stream_response``
established for the Anthropic SDK; same correctness arguments
apply (signal state is authoritative for abort detection;
register-then-recheck closes the registration race; listener
is detached in a ``finally`` so long-lived controllers don't
accumulate dead listeners).
"""
from src.utils.abort_controller import AbortError

# Pre-call fast-path: matches AnthropicProvider. A signal that
# tripped at a turn boundary skips the API round-trip entirely.
# Mid-stream cancellation is handled below by a response-close
# listener registered around the OpenAI SDK's stream iterator.
if abort_signal is not None and getattr(abort_signal, "aborted", False):
raise AbortError(getattr(abort_signal, "reason", None) or "user_interrupt")
model = self._get_model(**kwargs)
provider_messages = self._prepare_messages(messages)
@@ -372,51 +392,122 @@ def chat_stream_response(
usage_obj: Any = None
tool_calls_by_index: dict[int, dict[str, str]] = {}

# --- Abort-listener wiring ---
# Close the underlying HTTP response when the signal trips so a
# blocking next-chunk read raises immediately. The OpenAI Python
# SDK 1.x and 2.x both expose the underlying httpx Response as
# ``stream.response`` (see ``openai/_streaming.py``).
# ``httpx.Response.close()`` is idempotent (guarded by
# ``if not self.is_closed``), so a double-close — e.g., the
# listener fires AND the post-loop path explicitly closes — is
# harmless.
def _close_stream_on_abort() -> None:
try:
response = getattr(stream, "response", None)
if response is not None:
close = getattr(response, "close", None)
if callable(close):
close()
except Exception:
# Best-effort — never let close() propagate out of the
# listener thread.
pass

abort_listener: Any = None
if abort_signal is not None:
# Register-then-recheck: see the Anthropic provider for the
# full race analysis. The TL;DR is that ``_fire`` snapshots
# the listener list before iterating, so a listener appended
# after that snapshot is silently dropped; the post-add
# ``aborted`` read closes the gap (signal state is sticky).
abort_listener = abort_signal.add_listener(
_close_stream_on_abort, once=True,
)
if abort_signal.aborted:
_close_stream_on_abort()

try:
for chunk in stream:
# In-loop abort check: even when the listener fires
# mid-stream, chunks already buffered by the SDK can
# still get yielded before the closed-socket raise lands.
# The in-loop check makes the abort observable on the
# very next chunk boundary regardless of buffering.
if abort_signal is not None and abort_signal.aborted:
break

response_model = getattr(chunk, "model", response_model)
usage_candidate = getattr(chunk, "usage", None)
if usage_candidate is not None:
usage_obj = usage_candidate

choices = getattr(chunk, "choices", None) or []
if not choices:
continue
choice = choices[0]
if getattr(choice, "finish_reason", None):
finish_reason = choice.finish_reason

delta = getattr(choice, "delta", None)
if delta is None:
continue
content_piece = getattr(delta, "content", None)
if content_piece:
piece = str(content_piece)
content_parts.append(piece)
if on_text_chunk is not None:
on_text_chunk(piece)

reasoning_piece = getattr(delta, "reasoning_content", None)
if reasoning_piece:
reasoning_parts.append(str(reasoning_piece))

tool_call_deltas = getattr(delta, "tool_calls", None) or []
for tc in tool_call_deltas:
idx = getattr(tc, "index", 0)
entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""})

tc_id = getattr(tc, "id", None)
if tc_id:
entry["id"] = str(tc_id)

function = getattr(tc, "function", None)
if function is not None:
fn_name = getattr(function, "name", None)
if fn_name:
entry["name"] += str(fn_name)
fn_args = getattr(function, "arguments", None)
if fn_args:
entry["arguments"] += str(fn_args)
except Exception as streaming_exc:
# Abort path: the listener closed the underlying HTTP
# response, which raised on the SDK's next read in the
# consumer thread. Detect via signal state (not exception
# class — the OpenAI/httpx layer can raise several different
# exception types depending on which syscall was in flight).
if abort_signal is not None and getattr(abort_signal, "aborted", False):
raise AbortError(
getattr(abort_signal, "reason", None) or "user_interrupt"
) from streaming_exc
raise
finally:
if abort_listener is not None and abort_signal is not None:
try:
abort_signal.remove_listener(abort_listener)
except Exception:
pass

# The stream may have completed naturally OR we broke out of
# the loop because the in-loop abort check fired. Surface the
# abort here so the caller bails at the same place every other
# path does. ``response.close()`` after a clean exit is a no-op
# on httpx (guarded by ``is_closed``), so the extra
# ``_close_stream_on_abort()`` call stays safe.
if abort_signal is not None and getattr(abort_signal, "aborted", False):
_close_stream_on_abort()
raise AbortError(
getattr(abort_signal, "reason", None) or "user_interrupt"
)

tool_uses: list[dict[str, Any]] = []
for idx in sorted(tool_calls_by_index.keys()):
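
From the caller's side, the contract both files now share reads like this. A usage sketch under stated assumptions: ``AbortController`` and its ``signal``/``abort()`` surface are placeholders, as is the provider construction; only the ``abort_signal=`` keyword, the ``on_text_chunk`` callback, and the ``AbortError`` boundary come from the diff above.

import threading

from src.utils.abort_controller import AbortController, AbortError  # AbortController assumed
from src.providers.openai_compatible import OpenAICompatibleProvider  # constructor args elided

controller = AbortController()            # assumed to own an AbortSignal
provider = OpenAICompatibleProvider()     # placeholder construction

def on_esc_pressed() -> None:
    # Runs on the input thread when the user hits ESC. Flipping the signal
    # makes the registered listener close the HTTP response, which unblocks
    # the SDK's blocking next-chunk read in the streaming thread.
    controller.abort("user_interrupt")

esc_watcher = threading.Timer(2.0, on_esc_pressed)  # stand-in for a real key handler
esc_watcher.start()

try:
    response = provider.chat_stream_response(
        messages=[{"role": "user", "content": "Write a long story."}],
        on_text_chunk=lambda piece: print(piece, end="", flush=True),
        abort_signal=controller.signal,
    )
except AbortError:
    # Every abort path in the diff (pre-call fast-path, listener-driven
    # close, in-loop break, post-loop recheck) surfaces at this boundary.
    print("\n[stream cancelled]")
finally:
    esc_watcher.cancel()

Whichever path fires, the caller sees one ``AbortError`` at one boundary, which is the point of the shared contract.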