fix(ai-proxy): check response tokens always string
fix(ai-proxy): user cannot select their own model if one is defined
fix(ai-proxy): plugin config should own the tuning parameters
fix(ai-proxy): correct model check precedence
tysoekong committed Jun 6, 2024
1 parent e5efc73 commit 0a93131
Showing 10 changed files with 59 additions and 26 deletions.
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-fix-model-parameter.yml
@@ -0,0 +1,5 @@
+message: |
+  **AI-proxy-plugin**: Fix the bug where Cohere and Anthropic providers don't read the `model` parameter properly
+  from the caller's request body.
+scope: Plugin
+type: bugfix
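For context, the `model` parameter in question arrives in the caller's JSON request body. A minimal sketch of such a body, shown as the Lua table it decodes to (the model name and message are illustrative, not taken from this commit):

```lua
-- illustrative decoded request body for an llm/v1/chat route; the Cohere and
-- Anthropic drivers previously failed to read body.model
local request_body = {
  model = "claude-2.1",  -- hypothetical model name
  messages = {
    { role = "user", content = "What is 1 + 1?" },
  },
}
```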
@@ -0,0 +1,5 @@
+message: |
+  **AI-proxy-plugin**: Fix the bug where using "OpenAI Function" inference requests would log a
+  request error, and then hang until timeout.
+scope: Plugin
+type: bugfix
5 changes: 5 additions & 0 deletions changelog/unreleased/kong/ai-proxy-fix-sending-own-model.yml
@@ -0,0 +1,5 @@
+message: |
+  **AI-proxy-plugin**: Fix a bug where AI Proxy would still allow callers to specify their own model,
+  ignoring the plugin-configured model name.
+scope: Plugin
+type: bugfix
@@ -0,0 +1,5 @@
+message: |
+  **AI-proxy-plugin**: Fix a bug where AI Proxy would not give precedence to the plugin's configured
+  model tuning options and would instead prefer the caller's JSON body parameters.
+scope: Plugin
+type: bugfix
5 changes: 0 additions & 5 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
@@ -442,11 +442,6 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-  -- check for user trying to bring own model
-  if body and body.model then
-    return nil, "cannot use own model for this instance"
-  end
-
  return true, nil
end

11 changes: 1 addition & 10 deletions kong/llm/drivers/cohere.lua
Original file line number Diff line number Diff line change
@@ -400,11 +400,6 @@ function _M.post_request(conf)
end

function _M.pre_request(conf, body)
-  -- check for user trying to bring own model
-  if body and body.model then
-    return false, "cannot use own model for this instance"
-  end
-
  return true, nil
end

@@ -467,7 +462,7 @@ end
function _M.configure_request(conf)
  local parsed_url

-  if conf.model.options.upstream_url then
+  if conf.model.options and conf.model.options.upstream_url then
    parsed_url = socket_url.parse(conf.model.options.upstream_url)
  else
    parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME])
@@ -476,10 +471,6 @@ function _M.configure_request(conf)
                      or ai_shared.operation_map[DRIVER_NAME][conf.route_type]
                         and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path
                      or "/"
-
-    if not parsed_url.path then
-      return false, fmt("operation %s is not supported for cohere provider", conf.route_type)
-    end
  end

  -- if the path is read from a URL capture, ensure that it is valid
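The `conf.model.options and ...` guard added near the top of `configure_request` matters because indexing a nil value raises an error in Lua. A minimal sketch of the failure mode it prevents, with an illustrative config shape:

```lua
-- plugin conf with no options table at all (illustrative)
local conf = { model = { name = "command", provider = "cohere" } }

-- old form: conf.model.options.upstream_url
--   --> error: attempt to index field 'options' (a nil value)

-- new form short-circuits safely:
local upstream = conf.model.options and conf.model.options.upstream_url
print(upstream)  --> nil, so the driver falls back to its default upstream URL
```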
14 changes: 8 additions & 6 deletions kong/llm/drivers/shared.lua
Original file line number Diff line number Diff line change
@@ -131,10 +131,10 @@ _M.clear_response_headers = {
-- @return {string} error if any is thrown - request should definitely be terminated if this is not nil
function _M.merge_config_defaults(request, options, request_format)
  if options then
-    request.temperature = request.temperature or options.temperature
-    request.max_tokens = request.max_tokens or options.max_tokens
-    request.top_p = request.top_p or options.top_p
-    request.top_k = request.top_k or options.top_k
+    request.temperature = options.temperature or request.temperature
+    request.max_tokens = options.max_tokens or request.max_tokens
+    request.top_p = options.top_p or request.top_p
+    request.top_k = options.top_k or request.top_k
  end

  return request, nil
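The four flipped assignments reverse the merge precedence: plugin-configured tuning values now win over whatever the caller sent, rather than the other way around. A self-contained sketch of the effect, with illustrative numbers:

```lua
-- caller's decoded JSON body vs. plugin-configured options (values illustrative)
local request = { temperature = 0.9, max_tokens = 256 }
local options = { temperature = 0.1, max_tokens = 1024, top_p = 0.5 }

-- old order: request.temperature or options.temperature  --> 0.9 (caller wins)
-- new order: options.temperature or request.temperature  --> 0.1 (plugin wins)
request.temperature = options.temperature or request.temperature
request.max_tokens  = options.max_tokens  or request.max_tokens
request.top_p       = options.top_p       or request.top_p  -- also fills gaps

-- request is now { temperature = 0.1, max_tokens = 1024, top_p = 0.5 }
```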
@@ -603,8 +603,10 @@ end
-- Function to count the number of words in a string
local function count_words(str)
  local count = 0
-  for word in str:gmatch("%S+") do
-    count = count + 1
+  if type(str) == "string" then
+    for word in str:gmatch("%S+") do
+      count = count + 1
+    end
  end
  return count
end
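A quick demonstration of why the type guard is needed: decoded JSON can hand this function `cjson.null` (a userdata sentinel) instead of a string, and calling `gmatch` on it raises a runtime error. A standalone sketch:

```lua
local cjson = require("cjson")

local function count_words(str)
  local count = 0
  if type(str) == "string" then
    for _ in str:gmatch("%S+") do
      count = count + 1
    end
  end
  return count
end

print(count_words("three little words"))  --> 3
print(count_words(cjson.null))            --> 0 (previously: attempt to index a userdata value)
print(count_words(nil))                   --> 0
```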
14 changes: 12 additions & 2 deletions kong/plugins/ai-proxy/handler.lua
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ end

local function get_token_text(event_t)
  -- chat
-  return
+  local text =
    event_t and
    event_t.choices and
    #event_t.choices > 0 and
@@ -52,6 +52,9 @@ local function get_token_text(event_t)
    event_t.choices[1].text

    or ""
+
+  -- sometimes users send/receive 'cjson.null' userdata, or a random JSON object
+  return (type(text) == "string" and text) or ""
end
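The same class of bad input is handled on the token path: a decoded stream event may carry `cjson.null` or a JSON object where a string is expected, so the new return coerces anything non-string to the empty string. A minimal check of that expression (the helper name here is mine, not from the commit):

```lua
local cjson = require("cjson")

-- mirrors the guard in get_token_text's return statement
local function coerce_token_text(text)
  return (type(text) == "string" and text) or ""
end

assert(coerce_token_text("hello") == "hello")
assert(coerce_token_text(cjson.null) == "")  -- JSON null decoded as userdata
assert(coerce_token_text({}) == "")          -- unexpected JSON object
assert(coerce_token_text(nil) == "")
```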

local function handle_streaming_frame(conf)
@@ -261,7 +264,7 @@ function _M:body_filter(conf)

  local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
  local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
-  
+
  if err then
    kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err)
  elseif new_response_string then
@@ -347,6 +350,13 @@ function _M:access(conf)
    conf_m.model.name = "NOT_SPECIFIED"
  end

+  -- check that the user isn't trying to override the plugin conf model in the request body
+  if request_table and request_table.model and type(request_table.model) == "string" then
+    if request_table.model ~= conf_m.model.name then
+      return bad_request("cannot use own model - must be: " .. conf_m.model.name)
+    end
+  end
+
  -- model is stashed in the copied plugin conf, for consistency in transformation functions
  if not conf_m.model.name then
    return bad_request("model parameter not found in request, nor in gateway configuration")
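Together with the two `pre_request` deletions above, this access-phase check becomes the single place bring-your-own-model requests are rejected, and only when the caller's string `model` disagrees with the configured name. A standalone restatement of the logic (the helper name is mine, not from the commit):

```lua
-- returns an error string when the caller tries to override the configured model
local function model_override_error(configured_name, body_model)
  if body_model and type(body_model) == "string"
     and body_model ~= configured_name then
    return "cannot use own model - must be: " .. configured_name
  end
  return nil  -- no model sent, or it matches the configured one
end

assert(model_override_error("gpt-3.5-turbo", nil) == nil)
assert(model_override_error("gpt-3.5-turbo", "gpt-3.5-turbo") == nil)
print(model_override_error("gpt-3.5-turbo", "gpt-4"))
--> cannot use own model - must be: gpt-3.5-turbo
```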
6 changes: 3 additions & 3 deletions spec/03-plugins/38-ai-proxy/01-unit_spec.lua
Original file line number Diff line number Diff line change
@@ -629,7 +629,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
        SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS,
        {
          max_tokens = 1024,
-          top_p = 1.0,
+          top_p = 0.5,
        },
        "llm/v1/chat"
      )
@@ -638,9 +638,9 @@

      assert.is_nil(err)
      assert.same({
-        max_tokens = 256,
+        max_tokens = 1024,
        temperature = 0.1,
-        top_p = 0.2,
+        top_p = 0.5,
        some_extra_param = "string_val",
        another_extra_param = 0.5,
      }, formatted)
15 changes: 15 additions & 0 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
Original file line number Diff line number Diff line change
@@ -841,6 +841,21 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
        }, json.choices[1].message)
      end)

+      it("tries to override configured model", function()
+        local r = client:get("/openai/llm/v1/chat/good", {
+          headers = {
+            ["content-type"] = "application/json",
+            ["accept"] = "application/json",
+          },
+          body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"),
+        })
+
+        local body = assert.res_status(400 , r)
+        local json = cjson.decode(body)
+
+        assert.same(json, {error = { message = "cannot use own model - must be: gpt-3.5-turbo" } })
+      end)
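The `good_own_model.json` fixture itself is not part of this diff; for the 400 assertion above to fire, it presumably contains a chat body whose `model` differs from the configured `gpt-3.5-turbo`. An assumed shape, written as the Lua table it would decode to (the contents are a guess, not the actual fixture):

```lua
-- assumed contents of spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json
local own_model_body = {
  model = "gpt-4",  -- anything other than the configured gpt-3.5-turbo
  messages = {
    { role = "user", content = "What is 1 + 1?" },
  },
}
```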

it("bad upstream response", function()
local r = client:get("/openai/llm/v1/chat/bad_upstream_response", {
headers = {
