diff --git a/apisix/plugin.lua b/apisix/plugin.lua
index 723ededd0b72..73d4df77fdbb 100644
--- a/apisix/plugin.lua
+++ b/apisix/plugin.lua
@@ -1397,13 +1397,38 @@ function _M.run_global_rules(api_ctx, global_rules, conf_version, phase_name)
     end
 end
 
-function _M.lua_response_filter(api_ctx, headers, body)
+-- Print `body` downstream, then flush the output.
+-- @param body value passed unchanged to ngx_print
+-- @param wait boolean|nil only a literal `true` requests the synchronous flush
+-- @return boolean, string|nil true on success; otherwise false plus the error
+--         returned by ngx_print or ngx_flush.
+local function print_and_flush(body, wait)
+    local ok, err = ngx_print(body)
+    if not ok then
+        return false, err
+    end
+
+    -- `wait == true` turns an omitted (nil) argument into an explicit false
+    ok, err = ngx_flush(wait == true)
+    if not ok then
+        return false, err
+    end
+
+    return true
+end
+
+
+-- Send `body` downstream; when plugins are loaded, each plugin's
+-- `lua_body_filter` handler is looked up first.
+-- @param wait boolean When true, use synchronous flush (ngx.flush(true)) so callers
+--        can detect client disconnection. Defaults to false (async flush).
+-- @return boolean, string|nil true on success; false plus the error when
+--         printing or flushing the body fails.
+function _M.lua_response_filter(api_ctx, headers, body, wait)
     local plugins = api_ctx.plugins
     if not plugins or #plugins == 0 then
         -- if there is no any plugin, just print the original body to downstream
-        ngx_print(body)
-        ngx_flush()
-        return
+        return print_and_flush(body, wait)
     end
     for i = 1, #plugins, 2 do
         local phase_func = plugins[i]["lua_body_filter"]
@@ -1430,8 +1455,7 @@ function _M.lua_response_filter(api_ctx, headers, body)
 
         ::CONTINUE::
     end
-    ngx_print(body)
-    ngx_flush()
+    return print_and_flush(body, wait)
 end
 
 
diff --git a/apisix/plugins/ai-providers/base.lua b/apisix/plugins/ai-providers/base.lua
index cb49263f558b..83ac9447b61e 100644
--- a/apisix/plugins/ai-providers/base.lua
+++ b/apisix/plugins/ai-providers/base.lua
@@ -333,6 +333,20 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
     -- uncommitted and causing nginx to fall through to the balancer phase.
     local output_sent = false
 
+    -- Stop the upstream read after a chunk could not be delivered downstream:
+    -- log the error, close `res._httpc` when it is still set, and flag the LLM
+    -- request as done. Runs for any print/flush failure reported by
+    -- plugin.lua_response_filter, not only for a closed client connection.
+    local function abort_on_disconnect(flush_err)
+        core.log.info("client disconnected during AI streaming, ",
+                      "aborting upstream read: ", flush_err)
+        if res._httpc then
+            res._httpc:close()
+            res._httpc = nil
+        end
+        ctx.var.llm_request_done = true
+    end
+
     -- Runaway-upstream safeguards. Both are opt-in; unset means no cap.
     local max_duration_ms = conf and conf.max_stream_duration_ms
     local max_bytes = conf and conf.max_response_bytes
@@ -424,15 +438,24 @@ function _M.parse_streaming_response(self, ctx, res, target_proto, converter, co
             ::CONTINUE::
         end
 
-        -- Output: converter events or passthrough raw chunk
+        -- Output: converter events or passthrough raw chunk.
+        -- Pass wait=true for synchronous flush so we can detect client disconnection.
         if converter then
             for _, c in ipairs(converted_chunks) do
-                plugin.lua_response_filter(ctx, res.headers, c)
+                local ok, flush_err = plugin.lua_response_filter(ctx, res.headers, c, true)
                 output_sent = true
+                if not ok then
+                    abort_on_disconnect(flush_err)
+                    return
+                end
             end
         else
-            plugin.lua_response_filter(ctx, res.headers, chunk)
+            local ok, flush_err = plugin.lua_response_filter(ctx, res.headers, chunk, true)
             output_sent = true
+            if not ok then
+                abort_on_disconnect(flush_err)
+                return
+            end
         end
 
         -- Enforce runaway-upstream safeguards after processing the chunk.
diff --git a/t/cli/test_dns.sh b/t/cli/test_dns.sh
index f0e19a837597..e52d6530a0ef 100755
--- a/t/cli/test_dns.sh
+++ b/t/cli/test_dns.sh
@@ -158,6 +158,7 @@ curl -v -k -i -m 20 -o /dev/null -s -X PUT http://127.0.0.1:9180/apisix/admin/st
     }
 }'
 
+sleep 1 # wait for the stream route to propagate from etcd to stream workers
 curl http://127.0.0.1:9100 || true
 make stop
 sleep 0.1 # wait for logs output
diff --git a/t/plugin/ai-proxy-anthropic.t b/t/plugin/ai-proxy-anthropic.t
index 3223785b0453..5eb0b75d072a 100644
--- a/t/plugin/ai-proxy-anthropic.t
+++ b/t/plugin/ai-proxy-anthropic.t
@@ -442,7 +442,7 @@ Content-Type: application/json
 test-type: null-details
 --- error_code: 200
 --- response_body_like eval
-qr/"input_tokens":10.*"output_tokens":5/
+qr/(?s)(?=.*"input_tokens":10)(?=.*"output_tokens":5)/
 --- no_error_log
 [error]
 
diff --git a/t/plugin/ai-proxy-client-disconnect.t b/t/plugin/ai-proxy-client-disconnect.t
new file mode 100644
index 000000000000..654da7e27b3f
--- /dev/null
+++ b/t/plugin/ai-proxy-client-disconnect.t
@@ -0,0 +1,224 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+add_block_preprocessor(sub {
+    my ($block) = @_;
+
+    if (!defined $block->request) {
+        $block->set_value("request", "GET /t");
+    }
+
+    # Mock upstream: slow SSE server that keeps streaming chunks until a print or
+    # flush fails, recording how many it has flushed in the "test" shared dict.
+    my $http_config = $block->http_config // <<_EOC_;
+    server {
+        server_name slow_openai_sse;
+        listen 7750;
+
+        default_type 'text/event-stream';
+
+        location /v1/chat/completions {
+            content_by_lua_block {
+                ngx.header["Content-Type"] = "text/event-stream"
+                local dict = ngx.shared["test"]
+                dict:set("upstream_chunks", 0)
+                -- Stream up to 2000 chunks with 30ms sleep between each.
+                -- The proxy should abort well before this completes when
+                -- the client disconnects.
+                for i = 1, 2000 do
+                    local ok, err = ngx.print(
+                        'data: {"id":"chatcmpl-1","object":'
+                        .. '"chat.completion.chunk","choices":[{"delta":'
+                        .. '{"content":"tok"},"index":0,'
+                        .. '"finish_reason":null}],"usage":null}\\n\\n')
+                    if not ok then
+                        return
+                    end
+                    local flush_ok = ngx.flush(true)
+                    if not flush_ok then
+                        return
+                    end
+                    dict:set("upstream_chunks", i)
+                    ngx.sleep(0.03)
+                end
+            }
+        }
+
+        # Probe endpoint to read the current chunk count.
+        location /chunks {
+            content_by_lua_block {
+                local dict = ngx.shared["test"]
+                ngx.say(dict:get("upstream_chunks") or 0)
+            }
+        }
+    }
+_EOC_
+
+    $block->set_value("http_config", $http_config);
+});
+
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: set route for client disconnect test
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                ngx.HTTP_PUT,
+                [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy": {
+                            "provider": "openai",
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "options": {
+                                "model": "gpt-4",
+                                "stream": true
+                            },
+                            "override": {
+                                "endpoint": "http://localhost:7750"
+                            },
+                            "ssl_verify": false
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 2: client disconnect aborts upstream read early
+--- config
+    location /t {
+        content_by_lua_block {
+            local http = require("resty.http")
+            local httpc = http.new()
+
+            local ok, err = httpc:connect({
+                scheme = "http",
+                host = "localhost",
+                port = ngx.var.server_port,
+            })
+            if not ok then
+                ngx.status = 500
+                ngx.say("connect failed: ", err)
+                return
+            end
+
+            local res, err = httpc:request({
+                method = "POST",
+                headers = { ["Content-Type"] = "application/json" },
+                path = "/anything",
+                body = [[{"messages": [{"role": "user", "content": "hi"}]}]],
+            })
+            if not res then
+                ngx.status = 500
+                ngx.say("request failed: ", err)
+                return
+            end
+
+            -- Do 3 successful body reads, then close the connection abruptly.
+            for i = 1, 3 do
+                local chunk, rerr = res.body_reader()
+                if rerr or not chunk then
+                    ngx.status = 500
+                    ngx.say("unexpected end of stream at chunk ", i, ": ", rerr)
+                    return
+                end
+            end
+            httpc:close()
+
+            -- Allow time for the proxy to detect the disconnect and stop
+            -- reading from the upstream connection, then capture the chunk count.
+            -- 1s window: unfixed path produces ~33 more chunks (1000ms / 30ms per
+            -- chunk); fixed path stops within a few chunks of the disconnect.
+            ngx.sleep(1.0)
+
+            -- Read chunk count from the mock upstream's probe endpoint.
+            local probe = http.new()
+            ok, err = probe:connect({ scheme = "http", host = "localhost", port = 7750 })
+            if not ok then
+                ngx.status = 500
+                ngx.say("probe connect failed: ", err)
+                return
+            end
+            local probe_res, probe_err = probe:request({
+                method = "GET",
+                path = "/chunks",
+                headers = { Host = "localhost" },
+            })
+            if not probe_res then
+                ngx.status = 500
+                ngx.say("probe request failed: ", probe_err)
+                return
+            end
+            local count_str = probe_res:read_body()
+            probe:close()
+
+            if probe_res.status ~= 200 then
+                ngx.status = 500
+                ngx.say("probe status unexpected: ", probe_res.status)
+                return
+            end
+
+            local count = tonumber(count_str)
+            if not count then
+                ngx.status = 500
+                ngx.say("invalid probe response: ", count_str or "nil")
+                return
+            end
+
+            -- With the fix the upstream stops shortly after client disconnect
+            -- (well under 15 chunks). Without the fix it passes ~33 chunks in
+            -- the 1s observation window, so this threshold reliably catches the
+            -- regression while leaving ample headroom for timing variation.
+            if count > 15 then
+                ngx.status = 500
+                ngx.say("upstream was not aborted promptly, chunks: ", count)
+                return
+            end
+            ngx.say("ok, upstream aborted after ~", count, " chunks")
+        }
+    }
+--- response_body_like
+^ok, upstream aborted after ~\d+ chunks$
+--- error_log
+client disconnected during AI streaming