Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(plugins): ai-prompt-decorator-plugin #12336

Merged
merged 4 commits into from Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/labeler.yml
Expand Up @@ -92,7 +92,11 @@ plugins/acme:

plugins/ai-proxy:
- changed-files:
- any-glob-to-any-file: ['kong/plugins/ai-proxy/**/*', 'kong/llm/**/*']
- any-glob-to-any-file: ['kong/plugins/ai-proxy/**/*', 'kong/llm/**/*']

plugins/ai-prompt-decorator:
- changed-files:
- any-glob-to-any-file: kong/plugins/ai-prompt-decorator/**/*

plugins/aws-lambda:
- changed-files:
Expand Down
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/add-ai-prompt-decorator-plugin.yml
@@ -0,0 +1,3 @@
message: Introduced the new **AI Prompt Decorator** plugin that enables prepending and appending llm/v1/chat messages onto consumer LLM requests, for prompt tuning.
type: feature
scope: Plugin
3 changes: 3 additions & 0 deletions kong-3.6.0-0.rockspec
Expand Up @@ -574,6 +574,9 @@ build = {
["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua",
["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua",

["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua",
["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua",

["kong.vaults.env"] = "kong/vaults/env/init.lua",
["kong.vaults.env.schema"] = "kong/vaults/env/schema.lua",

Expand Down
1 change: 1 addition & 0 deletions kong/constants.lua
Expand Up @@ -37,6 +37,7 @@ local plugins = {
"zipkin",
"opentelemetry",
"ai-proxy",
"ai-prompt-decorator",
}

local plugin_map = {}
Expand Down
72 changes: 72 additions & 0 deletions kong/plugins/ai-prompt-decorator/handler.lua
@@ -0,0 +1,72 @@
local _M = {}

-- imports
local kong_meta = require "kong.meta"
local new_tab = require("table.new")
local EMPTY = {}
--

_M.PRIORITY = 772
_M.VERSION = kong_meta.version


local function bad_request(msg)
kong.log.debug(msg)
return kong.response.exit(400, { error = { message = msg } })
end

function _M.execute(request, conf)
local prepend = conf.prompts.prepend or EMPTY
local append = conf.prompts.append or EMPTY

if #prepend == 0 and #append == 0 then
return request, nil
end

local old_messages = request.messages
local new_messages = new_tab(#append + #prepend + #old_messages, 0)
request.messages = new_messages

local n = 0

for _, msg in ipairs(prepend) do
n = n + 1
new_messages[n] = { role = msg.role, content = msg.content }
end

for _, msg in ipairs(old_messages) do
n = n + 1
new_messages[n] = msg
end

for _, msg in ipairs(append) do
n = n + 1
new_messages[n] = { role = msg.role, content = msg.content }
end

return request, nil
end

function _M:access(conf)
kong.service.request.enable_buffering()
kong.ctx.shared.ai_prompt_decorated = true -- future use

-- if plugin ordering was altered, receive the "decorated" request
local request, err = kong.request.get_body("application/json")
if err then
return bad_request("this LLM route only supports application/json requests")
end

if not request.messages or #request.messages < 1 then
return bad_request("this LLM route only supports llm/chat type requests")
end

local decorated_request, err = self.execute(request, conf)
if err then
return bad_request(err)
end

kong.service.request.set_body(decorated_request, "application/json")
end
tysoekong marked this conversation as resolved.
Show resolved Hide resolved

return _M
50 changes: 50 additions & 0 deletions kong/plugins/ai-prompt-decorator/schema.lua
@@ -0,0 +1,50 @@
local typedefs = require "kong.db.schema.typedefs"

local prompt_record = {
type = "record",
required = false,
fields = {
{ role = { type = "string", required = true, one_of = { "system", "assistant", "user" }, default = "system" }},
{ content = { type = "string", required = true, len_min = 1, len_max = 500 } },
}
}

local prompts_record = {
type = "record",
required = false,
fields = {
{ prepend = {
type = "array",
description = "Insert chat messages at the beginning of the chat message array. "
.. "This array preserves exact order when adding messages.",
elements = prompt_record,
required = false,
len_max = 15,
}},
{ append = {
type = "array",
description = "Insert chat messages at the end of the chat message array. "
.. "This array preserves exact order when adding messages.",
elements = prompt_record,
required = false,
len_max = 15,
}},
}
}

return {
name = "ai-prompt-decorator",
fields = {
{ protocols = typedefs.protocols_http },
{ config = {
type = "record",
fields = {
{ prompts = prompts_record }
}
}
}
},
entity_checks = {
{ at_least_one_of = { "config.prompts.prepend", "config.prompts.append" } },
},
}
1 change: 1 addition & 0 deletions spec/01-unit/12-plugins_order_spec.lua
Expand Up @@ -72,6 +72,7 @@ describe("Plugins", function()
"response-ratelimiting",
"request-transformer",
"response-transformer",
"ai-prompt-decorator",
"ai-proxy",
"aws-lambda",
"azure-functions",
Expand Down
90 changes: 90 additions & 0 deletions spec/03-plugins/41-ai-prompt-decorator/00-config_spec.lua
@@ -0,0 +1,90 @@
local PLUGIN_NAME = "ai-prompt-decorator"


-- helper function to validate data against a schema
local validate do
local validate_entity = require("spec.helpers").validate_plugin_config_schema
local plugin_schema = require("kong.plugins."..PLUGIN_NAME..".schema")

function validate(data)
return validate_entity(data, plugin_schema)
end
end

describe(PLUGIN_NAME .. ": (schema)", function()
it("won't allow empty config object", function()
local config = {
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.not_nil(err)
assert.equal("at least one of these fields must be non-empty: 'config.prompts.prepend', 'config.prompts.append'", err["@entity"][1])
end)

it("won't allow both head and tail to be unset", function()
local config = {
prompts = {},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.not_nil(err)
assert.equal("at least one of these fields must be non-empty: 'config.prompts.prepend', 'config.prompts.append'", err["@entity"][1])
end)

it("won't allow both allow_patterns and deny_patterns to be empty arrays", function()
local config = {
prompts = {
prepend = {},
append = {},
},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.not_nil(err)
assert.equal("at least one of these fields must be non-empty: 'config.prompts.prepend', 'config.prompts.append'", err["@entity"][1])
end)

it("allows prepend only", function()
local config = {
prompts = {
prepend = {
[1] = {
role = "system",
content = "Prepend text 1 here.",
},
},
append = {},
},
}

local ok, err = validate(config)

assert.is_truthy(ok)
assert.is_nil(err)
end)

it("allows append only", function()
local config = {
prompts = {
prepend = {},
append = {
[1] = {
role = "system",
content = "Prepend text 1 here.",
},
},
},
}

local ok, err = validate(config)

assert.is_truthy(ok)
assert.is_nil(err)
end)
end)