Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 104 additions & 2 deletions apisix/plugins/ai-rate-limiting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ local require = require
local setmetatable = setmetatable
local ipairs = ipairs
local type = type
local pairs = pairs
local pcall = pcall
local load = load
local math_floor = math.floor
local math_huge = math.huge
local core = require("apisix.core")
local limit_count = require("apisix.plugins.limit-count.init")

Expand Down Expand Up @@ -61,10 +66,19 @@ local schema = {
show_limit_quota_header = {type = "boolean", default = true},
limit_strategy = {
type = "string",
enum = {"total_tokens", "prompt_tokens", "completion_tokens"},
enum = {"total_tokens", "prompt_tokens", "completion_tokens", "expression"},
default = "total_tokens",
description = "The strategy to limit the tokens"
},
cost_expr = {
type = "string",
minLength = 1,
description = "Lua arithmetic expression for dynamic token cost calculation. "
.. "Variables are injected from the LLM API raw usage response fields. "
.. "Missing variables default to 0. "
.. "Only valid when limit_strategy is 'expression'. "
.. "Example: input_tokens + cache_creation_input_tokens + output_tokens",
},
instances = {
type = "array",
items = instance_limit_schema,
Expand Down Expand Up @@ -136,8 +150,42 @@ local limit_conf_cache = core.lrucache.new({
})


-- Names that a cost expression may reference in addition to usage variables.
--
-- `math` is exposed through a proxy rather than as the real library table.
-- A cost expression is arbitrary Lua, so a function-call expression such as
-- `(function() math.floor = nil end)()` could otherwise monkey-patch the
-- process-wide `math` table. Reads fall through to the real library via
-- __index; writes are discarded by __newindex and never reach it.
local expr_safe_env = {
    math = setmetatable({}, {
        __index = math,
        __newindex = function()
            -- intentionally a no-op: the sandboxed math table is read-only
        end,
    }),
    abs = math.abs,
    ceil = math.ceil,
    floor = math.floor,
    max = math.max,
    min = math.min,
}

-- Syntax-check a cost expression.
-- The expression is wrapped as `return <expr>` and loaded as a text chunk
-- with the safe environment; the chunk is not executed here.
-- @param expr_str the expression text from the plugin configuration
-- @return the wrapped source string on success, or nil plus the loader error
local function compile_cost_expr(expr_str)
    local source = "return " .. expr_str
    local chunk, load_err = load(source, "cost_expr", "t", expr_safe_env)
    if chunk then
        return source
    end
    return nil, load_err
end


-- Validate a plugin configuration.
-- First runs the schema check; then, when limit_strategy is "expression",
-- additionally requires a non-empty cost_expr that loads as valid Lua.
-- (The earlier unconditional `return core.schema.check(schema, conf)` has
-- been dropped: a `return` followed by further statements is not valid Lua
-- and would have bypassed the cost_expr validation below.)
-- @param conf plugin configuration table
-- @return true when valid; otherwise false plus an error message
function _M.check_schema(conf)
    local ok, err = core.schema.check(schema, conf)
    if not ok then
        return false, err
    end

    if conf.limit_strategy == "expression" then
        if not conf.cost_expr or conf.cost_expr == "" then
            return false, "cost_expr is required when limit_strategy is 'expression'"
        end

        -- only the error slot matters here: compile_cost_expr returns
        -- nil, err when the expression fails to load
        local _, compile_err = compile_cost_expr(conf.cost_expr)
        if compile_err then
            return false, "invalid cost_expr: " .. compile_err
        end
    end

    return true
end


Expand Down Expand Up @@ -264,7 +312,57 @@ function _M.check_instance_status(conf, ctx, instance_name)
end


-- Evaluate a cost expression against a raw usage table.
-- Numeric fields of `raw` become variables of the expression, except fields
-- whose name matches an entry of expr_safe_env (those keep the safe entry).
-- Any other name the expression reads resolves to 0.
-- @param conf_cost_expr expression text (without the leading `return`)
-- @param raw table of usage fields; only number values are used
-- @return the result clamped to >= 0 and rounded to the nearest integer
--         (halves round up), or nil plus an error message when the
--         expression fails to load, raises, returns a non-number, or
--         returns NaN / +-infinity
local function eval_cost_expr(conf_cost_expr, raw)
    local source = "return " .. conf_cost_expr

    -- fallback lookup for names not stored on the environment itself
    local function resolve(_, name)
        local safe = expr_safe_env[name]
        if safe == nil then
            return 0
        end
        return safe
    end

    local vars = setmetatable({}, {__index = resolve})
    for field, value in pairs(raw) do
        if type(value) == "number" then
            if not expr_safe_env[field] then
                vars[field] = value
            end
        end
    end

    local chunk, load_err = load(source, "cost_expr", "t", vars)
    if not chunk then
        return nil, "failed to compile cost_expr: " .. load_err
    end

    local ok, value = pcall(chunk)
    if not ok then
        return nil, "failed to evaluate cost_expr: " .. value
    end

    local kind = type(value)
    if kind ~= "number" then
        return nil, "cost_expr must return a number, got: " .. kind
    end

    -- NaN is the only value that differs from itself
    local finite = value == value and value ~= math_huge and value ~= -math_huge
    if not finite then
        return nil, "cost_expr returned non-finite value"
    end

    -- negative costs count as zero; then round half up
    local clamped = value
    if clamped < 0 then
        clamped = 0
    end
    return math_floor(clamped + 0.5)
end

local function get_token_usage(conf, ctx)
if conf.limit_strategy == "expression" then
local raw = ctx.llm_raw_usage
if not raw then
return
end
local result, err = eval_cost_expr(conf.cost_expr, raw)
if not result then
core.log.error(err)
return
end
return result
end

local usage = ctx.ai_token_usage
if not usage then
return
Expand All @@ -288,6 +386,10 @@ function _M.log(conf, ctx)
core.log.error("failed to get token usage for llm service")
return
end
if used_tokens == 0 then
core.log.info("token usage is 0, skip rate limiting")
return
end

core.log.info("instance name: ", instance_name, " used tokens: ", used_tokens)

Expand Down
Loading
Loading