192 changes: 136 additions & 56 deletions src/providers/openai_compatible.py
@@ -330,13 +330,28 @@ def chat_stream_response(
) -> ChatResponse:
"""Stream OpenAI-compatible chunks while rebuilding the final response.

ESC-cancellation lives in ``StreamAbortGuard`` (see
``_stream_abort.py``). This provider keeps the SDK-specific
iteration shape — bare ``for chunk in stream`` plus an
in-loop ``guard.aborted`` check that catches the case where
chunks arrive back-to-back fast enough that the listener's
close lands one iteration late (or where the SDK has already
prefetched chunks past the close point).
ESC-cancellation runs the SDK iteration on a daemon worker
thread that pushes chunks into a ``queue.Queue``. The main
thread polls the queue with a 100 ms timeout and re-checks
``guard.aborted`` between ticks. On abort the main thread
raises ``AbortError`` immediately and orphans the worker —
the worker dies when the underlying connection eventually
closes.

Why the worker indirection (vs. the simpler in-loop check
used in earlier revisions): the OpenAI Python SDK uses sync
``httpx`` for streaming, and ``response.close()`` from
another thread is purely advisory. For LiteLLM-proxied
connections (and certain other httpx + chunked-transfer
configurations) the SDK's blocking socket read doesn't
actually return when the response is "closed" — it keeps
consuming bytes. Unlike JavaScript's native ``fetch +
AbortSignal`` integration (which the TypeScript reference at
``typescript/src/services/api/openaiShim.ts`` uses), Python
has no portable way to make a sync blocking read honor an
abort from another thread, so the worker exists to keep the
main thread's response time independent of the SDK's
cooperation.
"""
from ._stream_abort import StreamAbortGuard

@@ -378,63 +393,128 @@ def chat_stream_response(
usage_obj: Any = None
tool_calls_by_index: dict[int, dict[str, str]] = {}

with guard.attach(stream):
# Worker-thread iteration. The OpenAI Python SDK uses sync
# ``httpx`` for streaming, and ``response.close()`` from another
# thread is best-effort — for LiteLLM-proxied connections (and
# some other httpx configurations) the SDK's blocking socket
# read doesn't actually return when the response is closed.
# Unlike JavaScript's native ``fetch + AbortSignal`` integration
# (which the TypeScript reference uses), Python has no portable
# way to make a sync blocking read honor an abort from another
# thread.
#
# Workaround: hoist the iteration onto a daemon worker thread
# that pushes chunks into a queue. The main thread polls the
# queue with a short timeout and re-checks ``guard.aborted``
# each tick. On abort we raise ``AbortError`` immediately and
# orphan the worker — it'll die when the underlying connection
# eventually closes (server-side, idle timeout, or the SDK's
# natural exhaustion). The cost is some wasted bandwidth on
# the orphaned read; the benefit is that the user's prompt
# comes back in ~100 ms regardless of LiteLLM/httpx behavior.
import queue as _queue
import threading as _threading

_DONE = object()
chunk_queue: _queue.Queue = _queue.Queue()

def _drain_stream() -> None:
try:
for chunk in stream:
# In-loop check catches the SDK-prefetched-chunks
# case: the listener's close lands but the SDK has
# already buffered several chunks ahead. We break
# before consuming the next one.
if guard.aborted:
break
for c in stream:
chunk_queue.put(c)
except BaseException as exc: # noqa: BLE001 — surface to consumer
chunk_queue.put(exc)
finally:
chunk_queue.put(_DONE)

worker = _threading.Thread(
target=_drain_stream,
daemon=True,
name=f"openai-stream-{id(stream)}",
)

response_model = getattr(chunk, "model", response_model)
usage_candidate = getattr(chunk, "usage", None)
if usage_candidate is not None:
usage_obj = usage_candidate
with guard.attach(stream):
worker.start()
while True:
try:
item = chunk_queue.get(timeout=0.1)
except _queue.Empty:
# No chunk available right now — check abort and
# loop. The 100 ms tick bounds how long the user
# waits between pressing ESC and the prompt
# returning, regardless of how slow / blocked the
# underlying SDK iteration is.
if guard.aborted:
# Use ``raise_if_post_aborted`` so the abort
# reason from the controller is preserved
# (rather than hardcoding ``"user_interrupt"``,
# which would silently downgrade a non-default
# reason like a future ``"rate_limit_backoff"``).
guard.raise_if_post_aborted()
continue

choices = getattr(chunk, "choices", None) or []
if not choices:
continue
if item is _DONE:
break
if isinstance(item, BaseException):
if isinstance(item, Exception):
guard.reraise_if_aborted(item)
raise item
# KeyboardInterrupt/SystemExit from the worker
# path — re-raise as-is so the outer signal-
# handling story stays intact.
raise item

chunk = item
response_model = getattr(chunk, "model", response_model)
usage_candidate = getattr(chunk, "usage", None)
if usage_candidate is not None:
usage_obj = usage_candidate

choices = getattr(chunk, "choices", None) or []
if choices:
choice = choices[0]
if getattr(choice, "finish_reason", None):
finish_reason = choice.finish_reason

delta = getattr(choice, "delta", None)
if delta is None:
continue

content_piece = getattr(delta, "content", None)
if content_piece:
piece = str(content_piece)
content_parts.append(piece)
if on_text_chunk is not None:
on_text_chunk(piece)

reasoning_piece = getattr(delta, "reasoning_content", None)
if reasoning_piece:
reasoning_parts.append(str(reasoning_piece))

tool_call_deltas = getattr(delta, "tool_calls", None) or []
for tc in tool_call_deltas:
idx = getattr(tc, "index", 0)
entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""})

tc_id = getattr(tc, "id", None)
if tc_id:
entry["id"] = str(tc_id)

function = getattr(tc, "function", None)
if function is not None:
fn_name = getattr(function, "name", None)
if fn_name:
entry["name"] += str(fn_name)
fn_args = getattr(function, "arguments", None)
if fn_args:
entry["arguments"] += str(fn_args)
except Exception as streaming_exc:
guard.reraise_if_aborted(streaming_exc)
raise
if delta is not None:
content_piece = getattr(delta, "content", None)
if content_piece:
piece = str(content_piece)
content_parts.append(piece)
if on_text_chunk is not None:
on_text_chunk(piece)

reasoning_piece = getattr(delta, "reasoning_content", None)
if reasoning_piece:
reasoning_parts.append(str(reasoning_piece))

tool_call_deltas = getattr(delta, "tool_calls", None) or []
for tc in tool_call_deltas:
idx = getattr(tc, "index", 0)
entry = tool_calls_by_index.setdefault(idx, {"id": "", "name": "", "arguments": ""})

tc_id = getattr(tc, "id", None)
if tc_id:
entry["id"] = str(tc_id)

function = getattr(tc, "function", None)
if function is not None:
fn_name = getattr(function, "name", None)
if fn_name:
entry["name"] += str(fn_name)
fn_args = getattr(function, "arguments", None)
if fn_args:
entry["arguments"] += str(fn_args)

# Check abort AFTER processing this chunk so any
# already-delivered content is preserved (matches the
# in-loop-check semantics from the old implementation:
# the chunk-list test pins that the chunk we received
# before the abort gets processed; we just don't take
# the next one).
if guard.aborted:
guard.raise_if_post_aborted()

# Stream completed naturally OR the in-loop check broke out.
# In the latter case the signal is already tripped; raise so
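A minimal standalone sketch of the worker-plus-queue pattern the hunk above applies. The project's ``StreamAbortGuard``/``AbortError`` machinery is replaced here by a plain ``threading.Event`` and a local ``AbortError`` stand-in (both are illustrative assumptions, not the project's real API):

import queue
import threading
import time


class AbortError(RuntimeError):
    """Illustrative stand-in for the project's AbortError."""


def consume_with_abort(stream, aborted: threading.Event, on_chunk, poll_s: float = 0.1) -> None:
    """Drain ``stream`` on a daemon worker; the main thread polls a queue
    so an abort check runs at least every ``poll_s`` seconds, even when
    the worker's blocking read never returns."""
    done = object()
    q: queue.Queue = queue.Queue()

    def drain() -> None:
        try:
            for chunk in stream:
                q.put(chunk)
        except BaseException as exc:  # surface worker-side errors to the consumer
            q.put(exc)
        finally:
            q.put(done)

    threading.Thread(target=drain, daemon=True).start()
    while True:
        try:
            item = q.get(timeout=poll_s)
        except queue.Empty:
            if aborted.is_set():
                raise AbortError  # worker is orphaned; it dies with the connection
            continue
        if item is done:
            return
        if isinstance(item, BaseException):
            raise item
        on_chunk(item)
        if aborted.is_set():  # post-chunk check: already-delivered content is kept
            raise AbortError


# Usage: an iterator stuck in a blocking read still unwinds in ~one tick.
def stuck_stream():
    time.sleep(3600)  # stands in for a socket read that never returns
    yield


flag = threading.Event()
threading.Timer(0.2, flag.set).start()
try:
    consume_with_abort(stuck_stream(), flag, print)
except AbortError:
    print("aborted promptly")

The trade-off is explicit in the sketch: an abort orphans the worker and wastes whatever the connection still delivers, in exchange for bounding the main thread's response time to one poll tick.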
126 changes: 126 additions & 0 deletions tests/test_openai_compat_abort_signal.py
@@ -266,3 +266,129 @@ def test_listener_detached_after_normal_completion() -> None:
)

assert controller.signal._listeners == []


class _StuckStream:
"""Mimic an OpenAI Stream whose iterator never honors ``response.close()``.

Models the LiteLLM/proxy scenario reported by the user: the
underlying socket is not interrupted when ``stream.response.close()``
is called from another thread, so the SDK iterator stays blocked
on the next chunk indefinitely. The worker-thread iteration in
``OpenAICompatibleProvider.chat_stream_response`` must NOT rely on
the iterator unblocking — the main thread polls a queue with
timeout and bails on abort.

``__iter__`` blocks on an ``Event`` that the test never sets, so
iteration would hang forever without the worker+queue decoupling.
"""

def __init__(self) -> None:
self.response = MagicMock()
self._never_set = threading.Event()
self._iter_entered = threading.Event()

def __iter__(self):
self._iter_entered.set()
# Block forever — even if response.close() is called.
# ``_never_set`` is never set in this test.
self._never_set.wait()
# Unreachable. If we somehow get here, yield nothing so the
# iterator ends and the test doesn't go on forever.
return
yield # pragma: no cover


def test_abort_unwinds_promptly_even_when_iterator_never_returns() -> None:
"""The user's bug: ESC must unwind in <1s even when the SDK never honors close().

Pre-fix (single-threaded ``for chunk in stream``): the main thread
was blocked on ``next(stream)`` waiting for a chunk the LiteLLM
proxy never delivered, ``response.close()`` from the listener
thread didn't propagate to the kernel socket read, and ESC waited
indefinitely.

Post-fix (worker thread + queue): the SDK iteration runs on a
daemon worker that gets orphaned on abort. The main thread polls
the queue with a 100 ms timeout and bails on ``guard.aborted``.
Total ESC-to-AbortError budget is one poll tick plus listener
cascade — well under 1 second on any reasonable machine.

Failure mode this regression-tests against: someone reverting the
worker+queue would make the main thread block on ``next(stream)``
again. With ``_StuckStream``'s never-set Event, the test would
hang forever (the failure surfaces as a CI timeout rather than a
fast assertion failure, but a timeout is still loud).
"""
controller = AbortController()
stream = _StuckStream()
provider = _provider_with_stream(stream)

def _trip_after_worker_starts() -> None:
# Wait for the worker thread to actually enter the iterator,
# so the test pins "abort during a stuck iteration" rather
# than "abort before the worker started".
assert stream._iter_entered.wait(timeout=2.0), "worker never entered iterator"
controller.abort("user_interrupt")

threading.Thread(target=_trip_after_worker_starts, daemon=True).start()

start = time.monotonic()
with pytest.raises(AbortError):
provider.chat_stream_response(
messages=[{"role": "user", "content": "hi"}],
abort_signal=controller.signal,
)
elapsed = time.monotonic() - start

# 100 ms poll tick + listener cascade + abort propagation. 1.5 s
# is comfortable headroom on slow CI; on a healthy laptop this is
# well under 300 ms.
assert elapsed < 1.5, f"abort took {elapsed:.2f}s — expected <1.5s"


class _ContentThenUsageStream:
"""Stream that yields one content chunk then a final usage-only chunk.

Mirrors OpenAI's streaming wire format when
``stream_options.include_usage=True``: content/delta chunks first,
then a final chunk with empty ``choices`` and populated ``usage``.
"""

def __init__(self) -> None:
self.response = MagicMock()

def __iter__(self):
# Regular content chunk.
yield _FakeChunk(content="hello")
# Final usage-only chunk: empty choices, populated usage.
final = MagicMock()
final.model = "test-model"
final.choices = []
final.usage = MagicMock(
prompt_tokens=10, completion_tokens=5, total_tokens=15,
)
yield final


def test_normal_completion_still_captures_final_usage() -> None:
"""The worker+queue path must not drop the final usage chunk.

OpenAI emits usage stats only in the last chunk (with empty
``choices``). The main thread must drain every queued chunk
before breaking on ``_DONE`` — otherwise token counting would
silently regress for non-aborted streams.
"""
controller = AbortController()
stream = _ContentThenUsageStream()
provider = _provider_with_stream(stream)

response = provider.chat_stream_response(
messages=[{"role": "user", "content": "hi"}],
abort_signal=controller.signal,
)
assert response.content == "hello"
# The final usage chunk made it through the queue; otherwise
# ``response.usage`` would be the default empty dict, and the
# ``↓ N tokens`` REPL spinner would silently lose count.
assert response.usage.get("total_tokens") == 15
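
For reference, the wire behavior ``_ContentThenUsageStream`` mirrors can be reproduced against the real SDK via OpenAI's documented ``stream_options`` flag. A sketch (the model name is illustrative):

from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage is not None:
        # Final chunk: empty choices, populated usage. This is the chunk
        # the worker+queue path must drain before honoring _DONE.
        print(f"\nusage: {chunk.usage.total_tokens} tokens")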