feat(plugins): ai-prompt-decorator-plugin (#12336)
* feat(plugins): ai-prompt-decorator-plugin

* fix(ai-prompt-decorator): changes from PR discussion

* fix(spec): plugin ordering

* Update schema.lua

---------

Co-authored-by: Jack Tysoe <jack@tys.one>
tysoekong and ttyS0e committed Jan 23, 2024
1 parent d2cf328 commit e4031c3
Showing 10 changed files with 502 additions and 1 deletion.
6 changes: 5 additions & 1 deletion .github/labeler.yml
@@ -92,7 +92,11 @@ plugins/acme:

plugins/ai-proxy:
- changed-files:
  - any-glob-to-any-file: ['kong/plugins/ai-proxy/**/*', 'kong/llm/**/*']

plugins/ai-prompt-decorator:
- changed-files:
  - any-glob-to-any-file: kong/plugins/ai-prompt-decorator/**/*

plugins/aws-lambda:
- changed-files:
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/add-ai-prompt-decorator-plugin.yml
@@ -0,0 +1,3 @@
message: Introduced the new **AI Prompt Decorator** plugin, which enables prepending and appending llm/v1/chat messages to consumer LLM requests for prompt tuning.
type: feature
scope: Plugin
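
For orientation, the `config` block that the new schema (added below) validates has roughly this shape. This is an illustrative sketch only; the prompt texts are invented for the example and are not part of the commit:

-- Illustrative only: a config table of the shape the new schema validates.
-- The prompt contents here are made up for demonstration.
local example_config = {
  prompts = {
    prepend = {
      { role = "system", content = "You are a helpful assistant for internal staff." },
    },
    append = {
      { role = "user", content = "Keep every answer under 50 words." },
    },
  },
}

Each `prepend` entry is injected before the consumer's chat messages and each `append` entry after them, preserving list order.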
3 changes: 3 additions & 0 deletions kong-3.6.0-0.rockspec
@@ -574,6 +574,9 @@ build = {
["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua",
["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua",

["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua",
["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua",

["kong.vaults.env"] = "kong/vaults/env/init.lua",
["kong.vaults.env.schema"] = "kong/vaults/env/schema.lua",

1 change: 1 addition & 0 deletions kong/constants.lua
@@ -37,6 +37,7 @@ local plugins = {
"zipkin",
"opentelemetry",
"ai-proxy",
"ai-prompt-decorator",
}

local plugin_map = {}
72 changes: 72 additions & 0 deletions kong/plugins/ai-prompt-decorator/handler.lua
@@ -0,0 +1,72 @@
local _M = {}

-- imports
local kong_meta = require "kong.meta"
local new_tab = require("table.new")
local EMPTY = {}
--

_M.PRIORITY = 772
_M.VERSION = kong_meta.version


local function bad_request(msg)
  kong.log.debug(msg)
  return kong.response.exit(400, { error = { message = msg } })
end


-- Builds a new messages array: configured prepend messages first, then the
-- original request messages, then configured append messages, in order.
function _M.execute(request, conf)
  local prepend = conf.prompts.prepend or EMPTY
  local append = conf.prompts.append or EMPTY

  if #prepend == 0 and #append == 0 then
    return request, nil
  end

  local old_messages = request.messages
  local new_messages = new_tab(#append + #prepend + #old_messages, 0)
  request.messages = new_messages

  local n = 0

  for _, msg in ipairs(prepend) do
    n = n + 1
    new_messages[n] = { role = msg.role, content = msg.content }
  end

  for _, msg in ipairs(old_messages) do
    n = n + 1
    new_messages[n] = msg
  end

  for _, msg in ipairs(append) do
    n = n + 1
    new_messages[n] = { role = msg.role, content = msg.content }
  end

  return request, nil
end


function _M:access(conf)
  kong.service.request.enable_buffering()
  kong.ctx.shared.ai_prompt_decorated = true -- future use

  -- if plugin ordering was altered, receive the "decorated" request
  local request, err = kong.request.get_body("application/json")
  if err then
    return bad_request("this LLM route only supports application/json requests")
  end

  if not request.messages or #request.messages < 1 then
    return bad_request("this LLM route only supports llm/chat type requests")
  end

  local decorated_request, err = self.execute(request, conf)
  if err then
    return bad_request(err)
  end

  kong.service.request.set_body(decorated_request, "application/json")
end

return _M
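
Since `_M.execute` only transforms the request table, it can be exercised outside a running gateway. A minimal sketch, assuming the Kong source tree is on `package.path`; the two `package.loaded` stubs below are placeholders, not the real `kong.meta` / `table.new` implementations:

-- Minimal sketch only: calling the decoration logic directly.
package.loaded["kong.meta"] = { version = "0.0.0" }               -- stubbed version string
package.loaded["table.new"] = function(narr, nrec) return {} end  -- plain-table fallback

local handler = require "kong.plugins.ai-prompt-decorator.handler"

local conf = {
  prompts = {
    prepend = { { role = "system", content = "You are a concise assistant." } },
    append  = { { role = "user",   content = "Answer in one sentence." } },
  },
}

local request = { messages = { { role = "user", content = "What is an API gateway?" } } }

local decorated = handler.execute(request, conf)
for i, msg in ipairs(decorated.messages) do
  print(i, msg.role, msg.content)
end
-- expected order: the prepended system message, the original user message,
-- then the appended user message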
50 changes: 50 additions & 0 deletions kong/plugins/ai-prompt-decorator/schema.lua
@@ -0,0 +1,50 @@
local typedefs = require "kong.db.schema.typedefs"

local prompt_record = {
  type = "record",
  required = false,
  fields = {
    { role = { type = "string", required = true, one_of = { "system", "assistant", "user" }, default = "system" }},
    { content = { type = "string", required = true, len_min = 1, len_max = 500 } },
  }
}

local prompts_record = {
  type = "record",
  required = false,
  fields = {
    { prepend = {
      type = "array",
      description = "Insert chat messages at the beginning of the chat message array. "
                 .. "This array preserves exact order when adding messages.",
      elements = prompt_record,
      required = false,
      len_max = 15,
    }},
    { append = {
      type = "array",
      description = "Insert chat messages at the end of the chat message array. "
                 .. "This array preserves exact order when adding messages.",
      elements = prompt_record,
      required = false,
      len_max = 15,
    }},
  }
}

return {
  name = "ai-prompt-decorator",
  fields = {
    { protocols = typedefs.protocols_http },
    { config = {
      type = "record",
      fields = {
        { prompts = prompts_record }
      }
    }}
  },
  entity_checks = {
    { at_least_one_of = { "config.prompts.prepend", "config.prompts.append" } },
  },
}
1 change: 1 addition & 0 deletions spec/01-unit/12-plugins_order_spec.lua
@@ -72,6 +72,7 @@ describe("Plugins", function()
"response-ratelimiting",
"request-transformer",
"response-transformer",
"ai-prompt-decorator",
"ai-proxy",
"aws-lambda",
"azure-functions",
90 changes: 90 additions & 0 deletions spec/03-plugins/41-ai-prompt-decorator/00-config_spec.lua
@@ -0,0 +1,90 @@
local PLUGIN_NAME = "ai-prompt-decorator"


-- helper function to validate data against a schema
local validate do
  local validate_entity = require("spec.helpers").validate_plugin_config_schema
  local plugin_schema = require("kong.plugins."..PLUGIN_NAME..".schema")

  function validate(data)
    return validate_entity(data, plugin_schema)
  end
end

describe(PLUGIN_NAME .. ": (schema)", function()

  it("won't allow empty config object", function()
    local config = {}

    local ok, err = validate(config)

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.equal("at least one of these fields must be non-empty: 'config.prompts.prepend', 'config.prompts.append'", err["@entity"][1])
  end)

it("won't allow both head and tail to be unset", function()
local config = {
prompts = {},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.not_nil(err)
assert.equal("at least one of these fields must be non-empty: 'config.prompts.prepend', 'config.prompts.append'", err["@entity"][1])
end)

it("won't allow both allow_patterns and deny_patterns to be empty arrays", function()
local config = {
prompts = {
prepend = {},
append = {},
},
}

local ok, err = validate(config)

assert.is_falsy(ok)
assert.not_nil(err)
assert.equal("at least one of these fields must be non-empty: 'config.prompts.prepend', 'config.prompts.append'", err["@entity"][1])
end)

it("allows prepend only", function()
local config = {
prompts = {
prepend = {
[1] = {
role = "system",
content = "Prepend text 1 here.",
},
},
append = {},
},
}

local ok, err = validate(config)

assert.is_truthy(ok)
assert.is_nil(err)
end)

it("allows append only", function()
local config = {
prompts = {
prepend = {},
append = {
[1] = {
role = "system",
content = "Prepend text 1 here.",
},
},
},
}

local ok, err = validate(config)

assert.is_truthy(ok)
assert.is_nil(err)
end)
end)
