From def201f566ccf2dd9b670e2f38e401a0450b1cb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Garc=C3=ADa=20Cota?= Date: Tue, 28 Nov 2017 12:31:43 +0100 Subject: [PATCH] feat(plugins) convert plugins to routes & services --- tests(statsd) make statsd plugin work with routes & services --- feat(galileo, ldap) stop using ctx.api Plugins still using it: oauth2, rate limiting, response rate limiting --- use service_id & route_id in rate-limiting plugin --- feat(response-rate-limiting) use route_id & service_id in rrl --- feat(oauth2) use service_id in oauth2 plugin --- kong/core/plugins_iterator.lua | 6 +- kong/db/strategies/cassandra/services.lua | 69 ++- kong/plugins/galileo/handler.lua | 8 +- kong/plugins/ldap-auth/access.lua | 2 +- kong/plugins/oauth2/access.lua | 48 +- kong/plugins/oauth2/daos.lua | 35 +- kong/plugins/oauth2/migrations/cassandra.lua | 16 +- kong/plugins/oauth2/migrations/postgres.lua | 15 +- kong/plugins/rate-limiting/handler.lua | 13 +- .../rate-limiting/migrations/cassandra.lua | 19 +- .../rate-limiting/migrations/postgres.lua | 40 +- .../rate-limiting/policies/cluster.lua | 54 +- kong/plugins/rate-limiting/policies/init.lua | 52 +- kong/plugins/response-ratelimiting/access.lua | 7 +- .../plugins/response-ratelimiting/handler.lua | 8 +- kong/plugins/response-ratelimiting/log.lua | 8 +- .../migrations/cassandra.lua | 17 +- .../migrations/postgres.lua | 43 +- .../policies/cluster.lua | 56 +- .../response-ratelimiting/policies/init.lua | 57 +- kong/plugins/statsd/handler.lua | 19 +- spec/03-plugins/06-statsd/01-log_spec.lua | 238 +++----- .../24-rate-limiting/02-policies_spec.lua | 16 +- .../24-rate-limiting/04-access_spec.lua | 478 +++++++++------ .../02-policies_spec.lua | 16 +- .../04-access_spec.lua | 556 ++++++++++-------- spec/03-plugins/26-oauth2/01-schema_spec.lua | 251 ++++++-- spec/03-plugins/26-oauth2/02-api_spec.lua | 38 +- spec/03-plugins/26-oauth2/03-access_spec.lua | 137 +++-- .../26-oauth2/04-invalidations_spec.lua | 7 +- spec/fixtures/blueprints.lua | 22 +- 31 files changed, 1450 insertions(+), 901 deletions(-) diff --git a/kong/core/plugins_iterator.lua b/kong/core/plugins_iterator.lua index d637a67461c..c9ce0795382 100644 --- a/kong/core/plugins_iterator.lua +++ b/kong/core/plugins_iterator.lua @@ -65,7 +65,11 @@ local function load_plugin_configuration(route_id, return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) end if plugin ~= nil and plugin.enabled then - return plugin.config or {} + local cfg = plugin.config or {} + cfg.route_id = plugin.route_id + cfg.service_id = plugin.service_id + cfg.consumer_id = plugin.consumer_id + return cfg end end diff --git a/kong/db/strategies/cassandra/services.lua b/kong/db/strategies/cassandra/services.lua index 29a4f65afef..d7c5178e379 100644 --- a/kong/db/strategies/cassandra/services.lua +++ b/kong/db/strategies/cassandra/services.lua @@ -1,44 +1,48 @@ local cassandra = require "cassandra" -local _Services = {} - +local fmt = string.format -function _Services:delete(primary_key) - local ok, err_t = self.super.delete(self, primary_key) - if not ok then - return nil, err_t - end - local plugins = {} - local connector = self.connector - local cluster = connector.cluster +local _Services = {} - -- retrieve plugins associated with this Service - local query = "SELECT * FROM plugins WHERE service_id = ? ALLOW FILTERING" - local args = { cassandra.uuid(primary_key.id) } +local function select_by_service_id(cluster, table_name, service_id, errors) + local select_q = fmt("SELECT * FROM %s WHERE service_id = ?", + table_name) + local res = {} + local count = 0 - for rows, err in cluster:iterate(query, args) do + for rows, err in cluster:iterate(select_q, { cassandra.uuid(service_id) }) do if err then - return nil, self.errors:database_error("could not fetch plugins " .. - "for Service: " .. err) + return nil, + errors:database_error( + fmt("could not fetch %s for Service: %s", table_name, err)) end for i = 1, #rows do - table.insert(plugins, rows[i]) + count = count + 1 + res[count] = rows[i] end end - -- CASCADE delete associated plugins + return res +end + +local function delete_cascade(connector, table_name, service_id, errors) + local entities = select_by_service_id(connector.cluster, table_name, service_id, errors) + + for i = 1, #entities do + local delete_q = fmt("DELETE from %s WHERE id = ?", table_name) - for i = 1, #plugins do - local res, err = connector:query("DELETE FROM plugins WHERE id = ?", { - cassandra.uuid(plugins[i].id) + local res, err = connector:query(delete_q, { + cassandra.uuid(entities[i].id) }, nil, "write") + if not res then - return nil, self.errors:database_error("could not delete plugin " .. - "associated with Service: " .. err) + return nil, errors:database_error( + fmt("could not delete instance of %s associated with Service: %s", + table_name, err)) end end @@ -46,4 +50,23 @@ function _Services:delete(primary_key) end +function _Services:delete(primary_key) + local ok, err_t = self.super.delete(self, primary_key) + if not ok then + return nil, err_t + end + + local connector = self.connector + local service_id = primary_key.id + local errors = self.errors + + local ok1, err1 = delete_cascade(connector, "plugins", service_id, errors) + local ok2, err2 = delete_cascade(connector, "oauth2_tokens", service_id, errors) + local ok3, err3 = delete_cascade(connector, "oauth2_authorization_codes", service_id, errors) + + return ok1 and ok2 and ok3, + err1 or err2 or err3 +end + + return _Services diff --git a/kong/plugins/galileo/handler.lua b/kong/plugins/galileo/handler.lua index 731a13231d7..27897e5d593 100644 --- a/kong/plugins/galileo/handler.lua +++ b/kong/plugins/galileo/handler.lua @@ -10,7 +10,7 @@ local Buffer = require "kong.plugins.galileo.buffer" local read_body = ngx.req.read_body local get_body_data = ngx.req.get_body_data -local _alf_buffers = {} -- buffers per-api +local _alf_buffers = {} -- buffers per-route local _server_addr local GalileoHandler = BasePlugin:extend() @@ -51,9 +51,9 @@ function GalileoHandler:log(conf) GalileoHandler.super.log(self) local ctx = ngx.ctx - local api_id = ctx.api.id + local route_id = ctx.route.id - local buf = _alf_buffers[api_id] + local buf = _alf_buffers[route_id] if not buf then local err conf.server_addr = _server_addr @@ -62,7 +62,7 @@ function GalileoHandler:log(conf) ngx.log(ngx.ERR, "could not create ALF buffer: ", err) return end - _alf_buffers[api_id] = buf + _alf_buffers[route_id] = buf end local req_body, res_body diff --git a/kong/plugins/ldap-auth/access.lua b/kong/plugins/ldap-auth/access.lua index 79e65211640..aa16df140e7 100644 --- a/kong/plugins/ldap-auth/access.lua +++ b/kong/plugins/ldap-auth/access.lua @@ -90,7 +90,7 @@ local function authenticate(conf, given_credentials) return false end - local cache_key = "ldap_auth_cache:" .. ngx.ctx.api.id .. ":" .. given_username + local cache_key = "ldap_auth_cache:" .. ngx.ctx.route.id .. ":" .. given_username local credential, err = singletons.cache:get(cache_key, { ttl = conf.cache_ttl, neg_ttl = conf.cache_ttl, diff --git a/kong/plugins/oauth2/access.lua b/kong/plugins/oauth2/access.lua index b1f83660012..8a22b5dd907 100644 --- a/kong/plugins/oauth2/access.lua +++ b/kong/plugins/oauth2/access.lua @@ -36,7 +36,7 @@ local GRANT_PASSWORD = "password" local ERROR = "error" local AUTHENTICATED_USERID = "authenticated_userid" -local function generate_token(conf, api, credential, authenticated_userid, scope, state, expiration, disable_refresh) +local function generate_token(conf, service, credential, authenticated_userid, scope, state, expiration, disable_refresh) local token_expiration = expiration or conf.token_expiration local refresh_token @@ -49,12 +49,12 @@ local function generate_token(conf, api, credential, authenticated_userid, scope refresh_token_ttl = conf.refresh_token_ttl end - local api_id + local service_id if not conf.global_credentials then - api_id = api.id + service_id = service.id end local token, err = singletons.dao.oauth2_tokens:insert({ - api_id = api_id, + service_id = service_id, credential_id = credential.id, authenticated_userid = authenticated_userid, expires_in = token_expiration, @@ -180,12 +180,12 @@ local function authorize(conf) -- If there are no errors, keep processing the request if not response_params[ERROR] then if response_type == CODE then - local api_id + local service_id if not conf.global_credentials then - api_id = ngx.ctx.api.id + service_id = ngx.ctx.service.id end local authorization_code, err = singletons.dao.oauth2_authorization_codes:insert({ - api_id = api_id, + service_id = service_id, credential_id = client.id, authenticated_userid = parameters[AUTHENTICATED_USERID], scope = table.concat(scopes, " ") @@ -200,7 +200,7 @@ local function authorize(conf) } else -- Implicit grant, override expiration to zero - response_params = generate_token(conf, ngx.ctx.api, client, parameters[AUTHENTICATED_USERID], table.concat(scopes, " "), state, nil, true) + response_params = generate_token(conf, ngx.ctx.service, client, parameters[AUTHENTICATED_USERID], table.concat(scopes, " "), state, nil, true) is_implicit_grant = true end end @@ -312,17 +312,17 @@ local function issue_token(conf) if not response_params[ERROR] then if grant_type == GRANT_AUTHORIZATION_CODE then local code = parameters[CODE] - local api_id + local service_id if not conf.global_credentials then - api_id = ngx.ctx.api.id + service_id = ngx.ctx.service.id end - local authorization_code = code and singletons.dao.oauth2_authorization_codes:find_all({api_id = api_id, code = code})[1] + local authorization_code = code and singletons.dao.oauth2_authorization_codes:find_all({service_id = service_id, code = code})[1] if not authorization_code then response_params = {[ERROR] = "invalid_request", error_description = "Invalid " .. CODE} elseif authorization_code.credential_id ~= client.id then response_params = {[ERROR] = "invalid_request", error_description = "Invalid " .. CODE} else - response_params = generate_token(conf, ngx.ctx.api, client, authorization_code.authenticated_userid, authorization_code.scope, state) + response_params = generate_token(conf, ngx.ctx.service, client, authorization_code.authenticated_userid, authorization_code.scope, state) singletons.dao.oauth2_authorization_codes:delete({id=authorization_code.id}) -- Delete authorization code so it cannot be reused end elseif grant_type == GRANT_CLIENT_CREDENTIALS then @@ -335,7 +335,7 @@ local function issue_token(conf) if not ok then response_params = scopes -- If it's not ok, then this is the error message else - response_params = generate_token(conf, ngx.ctx.api, client, parameters.authenticated_userid, table.concat(scopes, " "), state, nil, true) + response_params = generate_token(conf, ngx.ctx.service, client, parameters.authenticated_userid, table.concat(scopes, " "), state, nil, true) end end elseif grant_type == GRANT_PASSWORD then @@ -350,16 +350,16 @@ local function issue_token(conf) if not ok then response_params = scopes -- If it's not ok, then this is the error message else - response_params = generate_token(conf, ngx.ctx.api, client, parameters.authenticated_userid, table.concat(scopes, " "), state) + response_params = generate_token(conf, ngx.ctx.service, client, parameters.authenticated_userid, table.concat(scopes, " "), state) end end elseif grant_type == GRANT_REFRESH_TOKEN then local refresh_token = parameters[REFRESH_TOKEN] - local api_id + local service_id if not conf.global_credentials then - api_id = ngx.ctx.api.id + service_id = ngx.ctx.service.id end - local token = refresh_token and singletons.dao.oauth2_tokens:find_all({api_id = api_id, refresh_token = refresh_token})[1] + local token = refresh_token and singletons.dao.oauth2_tokens:find_all({service_id = service_id, refresh_token = refresh_token})[1] if not token then response_params = {[ERROR] = "invalid_request", error_description = "Invalid " .. REFRESH_TOKEN} else @@ -367,7 +367,7 @@ local function issue_token(conf) if token.credential_id ~= client.id then response_params = {[ERROR] = "invalid_client", error_description = "Invalid client authentication"} else - response_params = generate_token(conf, ngx.ctx.api, client, token.authenticated_userid, token.scope, state) + response_params = generate_token(conf, ngx.ctx.service, client, token.authenticated_userid, token.scope, state) singletons.dao.oauth2_tokens:delete({id=token.id}) -- Delete old token end end @@ -387,12 +387,12 @@ local function issue_token(conf) }) end -local function load_token_into_memory(conf, api, access_token) - local api_id +local function load_token_into_memory(conf, service, access_token) + local service_id if not conf.global_credentials then - api_id = api.id + service_id = service.id end - local credentials, err = singletons.dao.oauth2_tokens:find_all { api_id = api_id, access_token = access_token } + local credentials, err = singletons.dao.oauth2_tokens:find_all { service_id = service_id, access_token = access_token } local result if err then return nil, err @@ -407,7 +407,7 @@ local function retrieve_token(conf, access_token) if access_token then local token_cache_key = singletons.dao.oauth2_tokens:cache_key(access_token) token, err = singletons.cache:get(token_cache_key, nil, - load_token_into_memory, conf, ngx.ctx.api, + load_token_into_memory, conf, ngx.ctx.service, access_token) if err then return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) @@ -506,7 +506,7 @@ local function do_authentication(conf) return false, {status = 401, message = {[ERROR] = "invalid_token", error_description = "The access token is invalid or has expired"}, headers = {["WWW-Authenticate"] = 'Bearer realm="service" error="invalid_token" error_description="The access token is invalid or has expired"'}} end - if (token.api_id and ngx.ctx.api.id ~= token.api_id) or (token.api_id == nil and not conf.global_credentials) then + if (token.service_id and ngx.ctx.service.id ~= token.service_id) or (token.service_id == nil and not conf.global_credentials) then return false, {status = 401, message = {[ERROR] = "invalid_token", error_description = "The access token is invalid or has expired"}, headers = {["WWW-Authenticate"] = 'Bearer realm="service" error="invalid_token" error_description="The access token is invalid or has expired"'}} end diff --git a/kong/plugins/oauth2/daos.lua b/kong/plugins/oauth2/daos.lua index 1a6eb26d3cf..1d90000443d 100644 --- a/kong/plugins/oauth2/daos.lua +++ b/kong/plugins/oauth2/daos.lua @@ -1,5 +1,8 @@ local utils = require "kong.tools.utils" local url = require "socket.url" +local Errors = require "kong.dao.errors" +local db_errors = require "kong.db.errors" + local function validate_uris(v, t, column) if v then @@ -19,6 +22,28 @@ local function validate_uris(v, t, column) return true, nil end + +local function validate_service_id(service_id, db) + + if service_id ~= nil then + local service, err, err_t = db.services:select({ + id = service_id, + }) + if err then + if err_t.code == db_errors.codes.DATABASE_ERROR then + return false, Errors.db(err) + end + + return false, Errors.schema(err_t) + end + + if not service then + return false, Errors.foreign("no such Service (id=" .. service_id .. ")") + end + end +end + + local OAUTH2_CREDENTIALS_SCHEMA = { primary_key = {"id"}, table = "oauth2_credentials", @@ -39,13 +64,17 @@ local OAUTH2_AUTHORIZATION_CODES_SCHEMA = { table = "oauth2_authorization_codes", fields = { id = { type = "id", dao_insert_value = true }, + service_id = { type = "id" }, --foreign = "services:id" -- manually tested in self_check api_id = { type = "id", required = false, foreign = "apis:id" }, credential_id = { type = "id", required = true, foreign = "oauth2_credentials:id" }, code = { type = "string", required = false, unique = true, immutable = true, default = utils.random_string }, authenticated_userid = { type = "string", required = false }, scope = { type = "string" }, created_at = { type = "timestamp", immutable = true, dao_insert_value = true } - } + }, + self_check = function(self, auth_t, dao, is_update) + return validate_service_id(auth_t.service_id, dao.db.new_db) + end, } local BEARER = "bearer" @@ -55,6 +84,7 @@ local OAUTH2_TOKENS_SCHEMA = { cache_key = { "access_token" }, fields = { id = { type = "id", dao_insert_value = true }, + service_id = { type = "id" }, --foreign = "services:id" -- manually tested in self_check api_id = { type = "id", required = false, foreign = "apis:id" }, credential_id = { type = "id", required = true, foreign = "oauth2_credentials:id" }, token_type = { type = "string", required = true, enum = { BEARER }, default = BEARER }, @@ -65,6 +95,9 @@ local OAUTH2_TOKENS_SCHEMA = { scope = { type = "string" }, created_at = { type = "timestamp", immutable = true, dao_insert_value = true } }, + self_check = function(self, token_t, dao, is_update) + return validate_service_id(token_t.service_id, dao.db.new_db) + end, } return { diff --git a/kong/plugins/oauth2/migrations/cassandra.lua b/kong/plugins/oauth2/migrations/cassandra.lua index 1f379415871..edced1f208f 100644 --- a/kong/plugins/oauth2/migrations/cassandra.lua +++ b/kong/plugins/oauth2/migrations/cassandra.lua @@ -189,5 +189,19 @@ return { end end, down = function(_, _, dao) end -- not implemented - } + }, + { + name = "2018-01-09-oauth2_c_add_service_id", + up = [[ + ALTER TABLE oauth2_authorization_codes ADD service_id uuid; + CREATE INDEX IF NOT EXISTS ON oauth2_authorization_codes(service_id); + + ALTER TABLE oauth2_tokens ADD service_id uuid; + CREATE INDEX IF NOT EXISTS ON oauth2_tokens(service_id); + ]], + down = [[ + ALTER TABLE oauth2_authorization_codes DROP service_id; + ALTER TABLE oauth2_tokens DROP service_id; + ]], + }, } diff --git a/kong/plugins/oauth2/migrations/postgres.lua b/kong/plugins/oauth2/migrations/postgres.lua index 885e7d93130..7614d82c710 100644 --- a/kong/plugins/oauth2/migrations/postgres.lua +++ b/kong/plugins/oauth2/migrations/postgres.lua @@ -202,5 +202,18 @@ return { end end, down = function(_, _, dao) end -- not implemented - } + }, + { + name = "2018-01-09-oauth2_pg_add_service_id", + up = [[ + ALTER TABLE oauth2_authorization_codes ADD COLUMN service_id uuid + REFERENCES services (id) ON DELETE CASCADE; + ALTER TABLE oauth2_tokens ADD COLUMN service_id uuid + REFERENCES services ON DELETE CASCADE; + ]], + down = [[ + ALTER TABLE oauth2_tokens DROP COLUMN service_id; + ALTER TABLE oauth2_authorization_codes DROP COLUMN service_id; + ]], + }, } diff --git a/kong/plugins/rate-limiting/handler.lua b/kong/plugins/rate-limiting/handler.lua index bef89c7381c..ba2d959e2f0 100644 --- a/kong/plugins/rate-limiting/handler.lua +++ b/kong/plugins/rate-limiting/handler.lua @@ -38,12 +38,12 @@ local function get_identifier(conf) return identifier end -local function get_usage(conf, api_id, identifier, current_timestamp, limits) +local function get_usage(conf, identifier, current_timestamp, limits) local usage = {} local stop for name, limit in pairs(limits) do - local current_usage, err = policies[conf.policy].usage(conf, api_id, identifier, current_timestamp, name) + local current_usage, err = policies[conf.policy].usage(conf, identifier, current_timestamp, name) if err then return nil, nil, err end @@ -75,7 +75,6 @@ function RateLimitingHandler:access(conf) -- Consumer is identified by ip address or authenticated_credential id local identifier = get_identifier(conf) - local api_id = ngx.ctx.api.id local policy = conf.policy local fault_tolerant = conf.fault_tolerant @@ -89,7 +88,7 @@ function RateLimitingHandler:access(conf) year = conf.year } - local usage, stop, err = get_usage(conf, api_id, identifier, current_timestamp, limits) + local usage, stop, err = get_usage(conf, identifier, current_timestamp, limits) if err then if fault_tolerant then ngx_log(ngx.ERR, "failed to get usage: ", tostring(err)) @@ -113,15 +112,15 @@ function RateLimitingHandler:access(conf) end end - local incr = function(premature, conf, limits, api_id, identifier, current_timestamp, value) + local incr = function(premature, conf, limits, identifier, current_timestamp, value) if premature then return end - policies[policy].increment(conf, limits, api_id, identifier, current_timestamp, value) + policies[policy].increment(conf, limits, identifier, current_timestamp, value) end -- Increment metrics for configured periods if the request goes through - local ok, err = ngx_timer_at(0, incr, conf, limits, api_id, identifier, current_timestamp, 1) + local ok, err = ngx_timer_at(0, incr, conf, limits, identifier, current_timestamp, 1) if not ok then ngx_log(ngx.ERR, "failed to create timer: ", err) end diff --git a/kong/plugins/rate-limiting/migrations/cassandra.lua b/kong/plugins/rate-limiting/migrations/cassandra.lua index 53d68599fd8..3fbbc8d8367 100644 --- a/kong/plugins/rate-limiting/migrations/cassandra.lua +++ b/kong/plugins/rate-limiting/migrations/cassandra.lua @@ -54,5 +54,22 @@ return { end end end - } + }, + { + name = "2017-11-30-120000_add_route_and_service_id", + up = [[ + DROP TABLE ratelimiting_metrics; + CREATE TABLE ratelimiting_metrics( + route_id uuid, + service_id uuid, + identifier text, + period text, + period_date timestamp, + value counter, + PRIMARY KEY ((route_id, service_id, identifier, period_date, period)) + ); + ]], + down = nil, + }, + } diff --git a/kong/plugins/rate-limiting/migrations/postgres.lua b/kong/plugins/rate-limiting/migrations/postgres.lua index 7861c734b48..033bd8aeef9 100644 --- a/kong/plugins/rate-limiting/migrations/postgres.lua +++ b/kong/plugins/rate-limiting/migrations/postgres.lua @@ -72,5 +72,43 @@ return { end end end - } + }, + { + name = "2017-11-30-120000_add_route_and_service_id", + up = [[ + ALTER TABLE ratelimiting_metrics ADD COLUMN route_id uuid; + ALTER TABLE ratelimiting_metrics ADD COLUMN service_id uuid; + + CREATE OR REPLACE FUNCTION increment_rate_limits(r_id uuid, s_id uuid, i text, p text, p_date timestamp with time zone, v integer) RETURNS VOID AS $$ + BEGIN + LOOP + UPDATE ratelimiting_metrics + SET value = value + v + WHERE route_id = r_id + AND service_id = s_id + AND identifier = i + AND period = p + AND period_date = p_date; + IF found then RETURN; + END IF; + + BEGIN + INSERT INTO ratelimiting_metrics(route_id, service_id, period, period_date, identifier, value) + VALUES(r_id, s_id, p, p_date, i, v); + RETURN; + EXCEPTION WHEN unique_violation THEN + END; + END LOOP; + END; + $$ LANGUAGE 'plpgsql'; + ]], + down = nil, + }, + { + name = "2017-11-30-130000_remove_api_id", + up = [[ + ALTER TABLE ratelimiting_metrics DROP COLUMN api_id; + ]], + down = nil, + }, } diff --git a/kong/plugins/rate-limiting/policies/cluster.lua b/kong/plugins/rate-limiting/policies/cluster.lua index 9c209e70899..c528e327275 100644 --- a/kong/plugins/rate-limiting/policies/cluster.lua +++ b/kong/plugins/rate-limiting/policies/cluster.lua @@ -8,7 +8,7 @@ local ERR = ngx.ERR return { ["cassandra"] = { - increment = function(db, limits, api_id, identifier, current_timestamp, value) + increment = function(db, limits, route_id, service_id, identifier, current_timestamp, value) local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do @@ -16,13 +16,15 @@ return { local res, err = db:query([[ UPDATE ratelimiting_metrics SET value = value + ? - WHERE api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? + WHERE route_id = ? + AND service_id = ? + AND identifier = ? + AND period_date = ? + AND period = ? ]], { db.cassandra.counter(value), - db.cassandra.uuid(api_id), + db.cassandra.uuid(route_id), + db.cassandra.uuid(service_id), identifier, db.cassandra.timestamp(period_date), period, @@ -36,18 +38,19 @@ return { return true end, - find = function(db, api_id, identifier, current_timestamp, period) + find = function(db, route_id, service_id, identifier, current_timestamp, period) local periods = timestamp.get_timestamps(current_timestamp) local rows, err = db:query([[ - SELECT * - FROM ratelimiting_metrics - WHERE api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? + SELECT * FROM ratelimiting_metrics + WHERE route_id = ? + AND service_id = ? + AND identifier = ? + AND period_date = ? + AND period = ? ]], { - db.cassandra.uuid(api_id), + db.cassandra.uuid(route_id), + db.cassandra.uuid(service_id), identifier, db.cassandra.timestamp(periods[period]), period, @@ -58,16 +61,16 @@ return { end, }, ["postgres"] = { - increment = function(db, limits, api_id, identifier, current_timestamp, value) + increment = function(db, limits, route_id, service_id, identifier, current_timestamp, value) local buf = {} local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then - buf[#buf+1] = fmt([[ - SELECT increment_rate_limits('%s', '%s', '%s', to_timestamp('%s') - at time zone 'UTC', %d) - ]], api_id, identifier, period, period_date/1000, value) + buf[#buf + 1] = fmt([[ + SELECT increment_rate_limits('%s', '%s', '%s', '%s', + to_timestamp('%s') at time zone 'UTC', %d) + ]], route_id, service_id, identifier, period, period_date/1000, value) end end @@ -78,17 +81,18 @@ return { return true end, - find = function(db, api_id, identifier, current_timestamp, period) + find = function(db, route_id, service_id, identifier, current_timestamp, period) local periods = timestamp.get_timestamps(current_timestamp) local q = fmt([[ SELECT *, extract(epoch from period_date)*1000 AS period_date FROM ratelimiting_metrics - WHERE api_id = '%s' AND - identifier = '%s' AND - period_date = to_timestamp('%s') at time zone 'UTC' AND - period = '%s' - ]], api_id, identifier, periods[period]/1000, period) + WHERE route_id = '%s' + AND service_id = '%s' + AND identifier = '%s' + AND period_date = to_timestamp('%s') at time zone 'UTC' + AND period = '%s' + ]], route_id, service_id, identifier, periods[period]/1000, period) local res, err = db:query(q) if not res or err then diff --git a/kong/plugins/rate-limiting/policies/init.lua b/kong/plugins/rate-limiting/policies/init.lua index 86bbd0dd79b..952f90bba1b 100644 --- a/kong/plugins/rate-limiting/policies/init.lua +++ b/kong/plugins/rate-limiting/policies/init.lua @@ -9,8 +9,25 @@ local shm = ngx.shared.kong_cache local pairs = pairs local fmt = string.format -local get_local_key = function(api_id, identifier, period_date, name) - return fmt("ratelimit:%s:%s:%s:%s", api_id, identifier, period_date, name) +local NULL_UUID = "00000000-0000-0000-0000-000000000000" + +local function get_ids(conf) + conf = conf or {} + local route_id = conf.route_id + if not route_id or route_id == ngx.null then + route_id = NULL_UUID + end + local service_id = conf.service_id + if not service_id or service_id == ngx.null then + service_id = NULL_UUID + end + return route_id, service_id +end + + +local get_local_key = function(conf, identifier, period_date, name) + local route_id, service_id = get_ids(conf) + return fmt("ratelimit:%s:%s:%s:%s:%s", route_id, service_id, identifier, period_date, name) end local EXPIRATIONS = { @@ -24,12 +41,11 @@ local EXPIRATIONS = { return { ["local"] = { - increment = function(conf, limits, api_id, identifier, current_timestamp, value) + increment = function(conf, limits, identifier, current_timestamp, value) local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then - local cache_key = get_local_key(api_id, identifier, period_date, period) - + local cache_key = get_local_key(conf, identifier, period_date, period) local newval, err = shm:incr(cache_key, value, 0) if not newval then ngx_log(ngx.ERR, "[rate-limiting] could not increment counter ", @@ -41,9 +57,9 @@ return { return true end, - usage = function(conf, api_id, identifier, current_timestamp, name) + usage = function(conf, identifier, current_timestamp, name) local periods = timestamp.get_timestamps(current_timestamp) - local cache_key = get_local_key(api_id, identifier, periods[name], name) + local cache_key = get_local_key(conf, identifier, periods[name], name) local current_metric, err = shm:get(cache_key) if err then return nil, err @@ -52,10 +68,11 @@ return { end }, ["cluster"] = { - increment = function(conf, limits, api_id, identifier, current_timestamp, value) + increment = function(conf, limits, identifier, current_timestamp, value) local db = singletons.dao.db - local ok, err = policy_cluster[db.name].increment(db, limits, api_id, identifier, - current_timestamp, value) + local route_id, service_id = get_ids(conf) + local ok, err = policy_cluster[db.name].increment(db, limits, route_id, service_id, + identifier, current_timestamp, value) if not ok then ngx_log(ngx.ERR, "[rate-limiting] cluster policy: could not increment ", db.name, " counter: ", err) @@ -63,10 +80,11 @@ return { return ok, err end, - usage = function(conf, api_id, identifier, current_timestamp, name) + usage = function(conf, identifier, current_timestamp, name) local db = singletons.dao.db - local row, err = policy_cluster[db.name].find(db, api_id, identifier, - current_timestamp, name) + local route_id, service_id = get_ids(conf) + local row, err = policy_cluster[db.name].find(db, route_id, service_id, + identifier, current_timestamp, name) if err then return nil, err end @@ -75,7 +93,7 @@ return { end }, ["redis"] = { - increment = function(conf, limits, api_id, identifier, current_timestamp, value) + increment = function(conf, limits, identifier, current_timestamp, value) local red = redis:new() red:set_timeout(conf.redis_timeout) local ok, err = red:connect(conf.redis_host, conf.redis_port) @@ -106,7 +124,7 @@ return { local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then - local cache_key = get_local_key(api_id, identifier, period_date, period) + local cache_key = get_local_key(conf, identifier, period_date, period) local exists, err = red:exists(cache_key) if err then ngx_log(ngx.ERR, "failed to query Redis: ", err) @@ -142,7 +160,7 @@ return { return true end, - usage = function(conf, api_id, identifier, current_timestamp, name) + usage = function(conf, identifier, current_timestamp, name) local red = redis:new() red:set_timeout(conf.redis_timeout) local ok, err = red:connect(conf.redis_host, conf.redis_port) @@ -170,7 +188,7 @@ return { reports.retrieve_redis_version(red) local periods = timestamp.get_timestamps(current_timestamp) - local cache_key = get_local_key(api_id, identifier, periods[name], name) + local cache_key = get_local_key(conf, identifier, periods[name], name) local current_metric, err = red:get(cache_key) if err then return nil, err diff --git a/kong/plugins/response-ratelimiting/access.lua b/kong/plugins/response-ratelimiting/access.lua index 078eabf2e33..a1cb747b614 100644 --- a/kong/plugins/response-ratelimiting/access.lua +++ b/kong/plugins/response-ratelimiting/access.lua @@ -29,12 +29,12 @@ local function get_identifier(conf) return identifier end -local function get_usage(conf, api_id, identifier, current_timestamp, limits) +local function get_usage(conf, identifier, current_timestamp, limits) local usage = {} for k, v in pairs(limits) do -- Iterate over limit names for lk, lv in pairs(v) do -- Iterare over periods - local current_usage, err = policies[conf.policy].usage(conf, api_id, identifier, current_timestamp, lk, k) + local current_usage, err = policies[conf.policy].usage(conf, identifier, current_timestamp, lk, k) if err then return nil, err end @@ -64,12 +64,11 @@ function _M.execute(conf) -- Load info local current_timestamp = timestamp.get_utc() ngx.ctx.current_timestamp = current_timestamp -- For later use - local api_id = ngx.ctx.api.id local identifier = get_identifier(conf) ngx.ctx.identifier = identifier -- For later use -- Load current metric for configured period - local usage, err = get_usage(conf, api_id, identifier, current_timestamp, conf.limits) + local usage, err = get_usage(conf, identifier, current_timestamp, conf.limits) if err then if conf.fault_tolerant then ngx.log(ngx.ERR, "failed to get usage: ", tostring(err)) diff --git a/kong/plugins/response-ratelimiting/handler.lua b/kong/plugins/response-ratelimiting/handler.lua index 1c2bad9a986..a16e94b715c 100644 --- a/kong/plugins/response-ratelimiting/handler.lua +++ b/kong/plugins/response-ratelimiting/handler.lua @@ -5,29 +5,35 @@ local access = require "kong.plugins.response-ratelimiting.access" local log = require "kong.plugins.response-ratelimiting.log" local header_filter = require "kong.plugins.response-ratelimiting.header_filter" + local ResponseRateLimitingHandler = BasePlugin:extend() + function ResponseRateLimitingHandler:new() ResponseRateLimitingHandler.super.new(self, "response-ratelimiting") end + function ResponseRateLimitingHandler:access(conf) ResponseRateLimitingHandler.super.access(self) access.execute(conf) end + function ResponseRateLimitingHandler:header_filter(conf) ResponseRateLimitingHandler.super.header_filter(self) header_filter.execute(conf) end + function ResponseRateLimitingHandler:log(conf) if not ngx.ctx.stop_log and ngx.ctx.usage then ResponseRateLimitingHandler.super.log(self) - log.execute(conf, ngx.ctx.api.id, ngx.ctx.identifier, ngx.ctx.current_timestamp, ngx.ctx.increments, ngx.ctx.usage) + log.execute(conf, ngx.ctx.identifier, ngx.ctx.current_timestamp, ngx.ctx.increments, ngx.ctx.usage) end end + ResponseRateLimitingHandler.PRIORITY = 900 ResponseRateLimitingHandler.VERSION = "0.1.0" diff --git a/kong/plugins/response-ratelimiting/log.lua b/kong/plugins/response-ratelimiting/log.lua index 1479e38fe2e..e0cf90c1d27 100644 --- a/kong/plugins/response-ratelimiting/log.lua +++ b/kong/plugins/response-ratelimiting/log.lua @@ -3,7 +3,7 @@ local pairs = pairs local _M = {} -local function log(premature, conf, api_id, identifier, current_timestamp, increments, usage) +local function log(premature, conf, identifier, current_timestamp, increments, usage) if premature then return end @@ -11,13 +11,13 @@ local function log(premature, conf, api_id, identifier, current_timestamp, incre -- Increment metrics for all periods if the request goes through for k, v in pairs(usage) do if increments[k] and increments[k] ~= 0 then - policies[conf.policy].increment(conf, api_id, identifier, current_timestamp, increments[k], k) + policies[conf.policy].increment(conf, identifier, current_timestamp, increments[k], k) end end end -function _M.execute(conf, api_id, identifier, current_timestamp, increments, usage) - local ok, err = ngx.timer.at(0, log, conf, api_id, identifier, current_timestamp, increments, usage) +function _M.execute(conf, identifier, current_timestamp, increments, usage) + local ok, err = ngx.timer.at(0, log, conf, identifier, current_timestamp, increments, usage) if not ok then ngx.log(ngx.ERR, "failed to create timer: ", err) end diff --git a/kong/plugins/response-ratelimiting/migrations/cassandra.lua b/kong/plugins/response-ratelimiting/migrations/cassandra.lua index 94f66e31c9a..fc1f7349905 100644 --- a/kong/plugins/response-ratelimiting/migrations/cassandra.lua +++ b/kong/plugins/response-ratelimiting/migrations/cassandra.lua @@ -55,5 +55,20 @@ return { end end end - } + }, { + name = "2017-12-19-120000_add_route_and_service_id_to_response_ratelimiting", + up = [[ + DROP TABLE response_ratelimiting_metrics; + CREATE TABLE response_ratelimiting_metrics( + route_id uuid, + service_id uuid, + identifier text, + period text, + period_date timestamp, + value counter, + PRIMARY KEY ((route_id, service_id, identifier, period_date, period)) + ); + ]], + down = nil, + }, } diff --git a/kong/plugins/response-ratelimiting/migrations/postgres.lua b/kong/plugins/response-ratelimiting/migrations/postgres.lua index 53cf9cbd026..dff56a34960 100644 --- a/kong/plugins/response-ratelimiting/migrations/postgres.lua +++ b/kong/plugins/response-ratelimiting/migrations/postgres.lua @@ -73,5 +73,46 @@ return { end end end - } + }, + { + name = "2017-12-19-120000_add_route_and_service_id_to_response_ratelimiting", + up = [[ + ALTER TABLE response_ratelimiting_metrics ADD COLUMN route_id uuid; + ALTER TABLE response_ratelimiting_metrics ADD COLUMN service_id uuid; + + CREATE OR REPLACE FUNCTION increment_response_rate_limits( + r_id uuid, s_id uuid, i text, p text, p_date timestamp with time zone, v integer) + RETURNS VOID AS $$ + BEGIN + LOOP + UPDATE response_ratelimiting_metrics + SET value = value + v + WHERE route_id = r_id + AND service_id = s_id + AND identifier = i + AND period = p + AND period_date = p_date; + IF found then RETURN; + END IF; + + BEGIN + INSERT INTO response_ratelimiting_metrics(route_id, service_id, period, + period_date, identifier, value) + VALUES(r_id, s_id, p, p_date, i, v); + RETURN; + EXCEPTION WHEN unique_violation THEN + END; + END LOOP; + END; + $$ LANGUAGE 'plpgsql'; + ]], + down = nil, + }, + { + name = "2017-12-19-130000_remove_api_id_from_response_ratelimiting", + up = [[ + ALTER TABLE response_ratelimiting_metrics DROP COLUMN api_id; + ]], + down = nil, + }, } diff --git a/kong/plugins/response-ratelimiting/policies/cluster.lua b/kong/plugins/response-ratelimiting/policies/cluster.lua index 734c12eb707..b61c0bff40c 100644 --- a/kong/plugins/response-ratelimiting/policies/cluster.lua +++ b/kong/plugins/response-ratelimiting/policies/cluster.lua @@ -6,26 +6,30 @@ local fmt = string.format local log = ngx.log local ERR = ngx.ERR + return { ["cassandra"] = { - increment = function(db, api_id, identifier, current_timestamp, value, name) + increment = function(db, route_id, service_id, identifier, current_timestamp, value, name) local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do local res, err = db:query([[ UPDATE response_ratelimiting_metrics SET value = value + ? - WHERE api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? + WHERE route_id = ? + AND service_id = ? + AND identifier = ? + AND period_date = ? + AND period = ? ]], { db.cassandra.counter(value), - db.cassandra.uuid(api_id), + db.cassandra.uuid(route_id), + db.cassandra.uuid(service_id), identifier, db.cassandra.timestamp(period_date), - name .. "_" .. period, + name .. "_" .. period }) + if not res then log(ERR, "[response-ratelimiting] cluster policy: could not increment ", "cassandra counter for period '", period, "': ", err) @@ -34,36 +38,39 @@ return { return true end, - find = function(db, api_id, identifier, current_timestamp, period, name) + find = function(db, route_id, service_id, identifier, current_timestamp, period, name) local periods = timestamp.get_timestamps(current_timestamp) local rows, err = db:query([[ SELECT * FROM response_ratelimiting_metrics - WHERE api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? + WHERE route_id = ? + AND service_id = ? + AND identifier = ? + AND period_date = ? + AND period = ? ]], { - db.cassandra.uuid(api_id), + db.cassandra.uuid(route_id), + db.cassandra.uuid(service_id), identifier, db.cassandra.timestamp(periods[period]), name .. "_" .. period, }) + if not rows then return nil, err elseif #rows <= 1 then return rows[1] else return nil, "bad rows result" end end, }, ["postgres"] = { - increment = function(db, api_id, identifier, current_timestamp, value, name) + increment = function(db, route_id, service_id, identifier, current_timestamp, value, name) local buf = {} local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do - buf[#buf+1] = fmt([[ - SELECT increment_response_rate_limits('%s', '%s', '%s_%s', to_timestamp('%s') - at time zone 'UTC', %d) - ]], api_id, identifier, name, period, period_date/1000, value) + buf[#buf + 1] = fmt([[ + SELECT increment_response_rate_limits('%s', '%s', '%s', '%s_%s', + to_timestamp('%s') at time zone 'UTC', %d) + ]], route_id, service_id, identifier, name, period, period_date/1000, value) end local res, err = db:query(concat(buf, ";")) @@ -73,17 +80,18 @@ return { return true end, - find = function(db, api_id, identifier, current_timestamp, period, name) + find = function(db, route_id, service_id, identifier, current_timestamp, period, name) local periods = timestamp.get_timestamps(current_timestamp) local q = fmt([[ SELECT *, extract(epoch from period_date)*1000 AS period_date FROM response_ratelimiting_metrics - WHERE api_id = '%s' AND - identifier = '%s' AND - period_date = to_timestamp('%s') at time zone 'UTC' AND - period = '%s_%s' - ]], api_id, identifier, periods[period]/1000, name, period) + WHERE route_id = '%s' + AND service_id = '%s' + AND identifier = '%s' + AND period_date = to_timestamp('%s') at time zone 'UTC' + AND period = '%s_%s' + ]], route_id, service_id, identifier, periods[period]/1000, name, period) local res, err = db:query(q) if not res or err then diff --git a/kong/plugins/response-ratelimiting/policies/init.lua b/kong/plugins/response-ratelimiting/policies/init.lua index d6ed13e7117..13d156b5ea8 100644 --- a/kong/plugins/response-ratelimiting/policies/init.lua +++ b/kong/plugins/response-ratelimiting/policies/init.lua @@ -9,10 +9,31 @@ local shm = ngx.shared.kong_cache local pairs = pairs local fmt = string.format -local get_local_key = function(api_id, identifier, period_date, name, period) - return fmt("response-ratelimit:%s:%s:%s:%s:%s", api_id, identifier, period_date, name, period) + +local NULL_UUID = "00000000-0000-0000-0000-000000000000" + + +local function get_ids(conf) + conf = conf or {} + local route_id = conf.route_id + if not route_id or route_id == ngx.null then + route_id = NULL_UUID + end + local service_id = conf.service_id + if not service_id or service_id == ngx.null then + service_id = NULL_UUID + end + return route_id, service_id end + +local get_local_key = function(conf, identifier, period_date, name, period) + local route_id, service_id = get_ids(conf) + return fmt("response-ratelimit:%s:%s:%s:%s:%s:%s", + route_id, service_id, identifier, period_date, name, period) +end + + local EXPIRATIONS = { second = 1, minute = 60, @@ -24,10 +45,10 @@ local EXPIRATIONS = { return { ["local"] = { - increment = function(conf, api_id, identifier, current_timestamp, value, name) + increment = function(conf, identifier, current_timestamp, value, name) local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do - local cache_key = get_local_key(api_id, identifier, period_date, name, period) + local cache_key = get_local_key(conf, identifier, period_date, name, period) local newval, err = shm:incr(cache_key, value, 0) if not newval then @@ -39,9 +60,9 @@ return { return true end, - usage = function(conf, api_id, identifier, current_timestamp, period, name) + usage = function(conf, identifier, current_timestamp, period, name) local periods = timestamp.get_timestamps(current_timestamp) - local cache_key = get_local_key(api_id, identifier, periods[period], name, period) + local cache_key = get_local_key(conf, identifier, periods[period], name, period) local current_metric, err = shm:get(cache_key) if err then return nil, err @@ -50,11 +71,11 @@ return { end }, ["cluster"] = { - increment = function(conf, api_id, identifier, current_timestamp, value, name) + increment = function(conf, identifier, current_timestamp, value, name) local db = singletons.dao.db - local ok, err = policy_cluster[db.name].increment(db, api_id, identifier, - current_timestamp, value, - name) + local route_id, service_id = get_ids(conf) + local ok, err = policy_cluster[db.name].increment(db, route_id, service_id, identifier, + current_timestamp, value, name) if not ok then ngx_log(ngx.ERR, "[response-ratelimiting] cluster policy: could not increment ", db.name, " counter: ", err) @@ -62,11 +83,11 @@ return { return ok, err end, - usage = function(conf, api_id, identifier, current_timestamp, period, name) + usage = function(conf, identifier, current_timestamp, period, name) local db = singletons.dao.db - local rows, err = policy_cluster[db.name].find(db, api_id, identifier, - current_timestamp, period, - name) + local route_id, service_id = get_ids(conf) + local rows, err = policy_cluster[db.name].find(db, route_id, service_id, identifier, + current_timestamp, period, name) if err then return nil, err end @@ -75,7 +96,7 @@ return { end }, ["redis"] = { - increment = function(conf, api_id, identifier, current_timestamp, value, name) + increment = function(conf, identifier, current_timestamp, value, name) local red = redis:new() red:set_timeout(conf.redis_timeout) local ok, err = red:connect(conf.redis_host, conf.redis_port) @@ -105,7 +126,7 @@ return { local idx = 0 local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do - local cache_key = get_local_key(api_id, identifier, period_date, name, period) + local cache_key = get_local_key(conf, identifier, period_date, name, period) local exists, err = red:exists(cache_key) if err then ngx_log(ngx.ERR, "failed to query Redis: ", err) @@ -140,7 +161,7 @@ return { return true end, - usage = function(conf, api_id, identifier, current_timestamp, period, name) + usage = function(conf, identifier, current_timestamp, period, name) local red = redis:new() red:set_timeout(conf.redis_timeout) local ok, err = red:connect(conf.redis_host, conf.redis_port) @@ -168,7 +189,7 @@ return { reports.retrieve_redis_version(red) local periods = timestamp.get_timestamps(current_timestamp) - local cache_key = get_local_key(api_id, identifier, periods[period], name, period) + local cache_key = get_local_key(conf, identifier, periods[period], name, period) local current_metric, err = red:get(cache_key) if err then return nil, err diff --git a/kong/plugins/statsd/handler.lua b/kong/plugins/statsd/handler.lua index fe7d5a3c0ef..5d4b0ae3c29 100644 --- a/kong/plugins/statsd/handler.lua +++ b/kong/plugins/statsd/handler.lua @@ -85,14 +85,17 @@ local function log(premature, conf, message) return end - local api_name = string_gsub(message.api.name, "%.", "_") + local name = string_gsub(message.service.name ~= ngx.null and + message.service.name or message.service.host, + "%.", "_") + local stat_name = { - request_size = api_name .. ".request.size", - response_size = api_name .. ".response.size", - latency = api_name .. ".latency", - upstream_latency = api_name .. ".upstream_latency", - kong_latency = api_name .. ".kong_latency", - request_count = api_name .. ".request.count", + request_size = name .. ".request.size", + response_size = name .. ".response.size", + latency = name .. ".latency", + upstream_latency = name .. ".upstream_latency", + kong_latency = name .. ".kong_latency", + request_count = name .. ".request.count", } local stat_value = { request_size = message.request.size, @@ -113,7 +116,7 @@ local function log(premature, conf, message) local metric = metrics[metric_config.name] if metric then - metric(api_name, message, metric_config, logger) + metric(name, message, metric_config, logger) else local stat_name = stat_name[metric_config.name] diff --git a/spec/03-plugins/06-statsd/01-log_spec.lua b/spec/03-plugins/06-statsd/01-log_spec.lua index a9eff8ef95f..08a259ebd03 100644 --- a/spec/03-plugins/06-statsd/01-log_spec.lua +++ b/spec/03-plugins/06-statsd/01-log_spec.lua @@ -1,11 +1,14 @@ local helpers = require "spec.helpers" +local fmt = string.format + + local UDP_PORT = 20000 for _, strategy in helpers.each_strategy() do - pending("Plugin: statsd (log) [#" .. strategy .. "]", function() + describe("Plugin: statsd (log) [#" .. strategy .. "]", function() local proxy_client setup(function() @@ -21,75 +24,32 @@ for _, strategy in helpers.each_strategy() do consumer_id = consumer.id, } - local route1 = bp.routes:insert { - hosts = { "logging1.com" }, - } - - local route2 = bp.routes:insert { - hosts = { "logging2.com" }, - } - - local route3 = bp.routes:insert { - hosts = { "logging3.com" }, - } - - local route4 = bp.routes:insert { - hosts = { "logging4.com" }, - } - - local route5 = bp.routes:insert { - hosts = { "logging5.com" }, - } - - local route6 = bp.routes:insert { - hosts = { "logging6.com" }, - } - - local route7 = bp.routes:insert { - hosts = { "logging7.com" }, - } - - local route8 = bp.routes:insert { - hosts = { "logging8.com" }, - } - - local route9 = bp.routes:insert { - hosts = { "logging9.com" }, - } - - local route10 = bp.routes:insert { - hosts = { "logging10.com" }, - } - - local route11 = bp.routes:insert { - hosts = { "logging11.com" }, - } - - local route12 = bp.routes:insert { - hosts = { "logging12.com" }, - } - - local route13 = bp.routes:insert { - hosts = { "logging13.com" }, - } + local routes = {} + for i = 1, 13 do + local service = bp.services:insert { + protocol = helpers.mock_upstream_protocol, + host = helpers.mock_upstream_host, + port = helpers.mock_upstream_port, + name = fmt("statsd%s", i) + } + routes[i] = bp.routes:insert { + hosts = { fmt("logging%d.com", i) }, + service = service + } + end - bp.plugins:insert { - name = "key-auth", - route_id = route1.id, - } + bp.key_auth_plugins:insert { route_id = routes[1].id } - bp.plugins:insert { - route_id = route1.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[1].id, config = { host = "127.0.0.1", port = UDP_PORT, }, } - bp.plugins:insert { - route_id = route2.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[2].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -102,9 +62,8 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - route_id = route3.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[3].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -118,9 +77,8 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - route_id = route4.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[4].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -133,9 +91,8 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - route_id = route5.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[5].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -149,9 +106,8 @@ for _, strategy in helpers.each_strategy() do } } - bp.plugins:insert { - route_id = route6.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[6].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -164,9 +120,8 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - route_id = route7.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[7].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -179,9 +134,8 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - route_id = route8.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[8].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -194,14 +148,10 @@ for _, strategy in helpers.each_strategy() do } } - bp.plugins:insert { - name = "key-auth", - route_id = route9.id, - } + bp.key_auth_plugins:insert { route_id = routes[9].id } - bp.plugins:insert { - route_id = route9.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[9].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -215,14 +165,10 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - name = "key-auth", - route_id = route10.id, - } + bp.key_auth_plugins:insert { route_id = routes[10].id } - bp.plugins:insert { - route_id = route10.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[10].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -237,14 +183,10 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - name = "key-auth", - route_id = route11.id, - } + bp.key_auth_plugins:insert { route_id = routes[11].id } - bp.plugins:insert { - route_id = route11.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[11].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -259,14 +201,10 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - name = "key-auth", - route_id = route12.id, - } + bp.key_auth_plugins:insert { route_id = routes[12].id } - bp.plugins:insert { - route_id = route12.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[12].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -280,14 +218,10 @@ for _, strategy in helpers.each_strategy() do }, } - bp.plugins:insert { - name = "key-auth", - route_id = route13.id, - } + bp.key_auth_plugins:insert { route_id = routes[13].id } - bp.plugins:insert { - route_id = route13.id, - name = "statsd", + bp.statsd_plugins:insert { + route_id = routes[13].id, config = { host = "127.0.0.1", port = UDP_PORT, @@ -343,19 +277,19 @@ for _, strategy in helpers.each_strategy() do local ok, metrics = thread:join() assert.True(ok) - assert.contains("kong.stastd1.request.count:1|c", metrics) - assert.contains("kong.stastd1.latency:%d+|ms", metrics, true) - assert.contains("kong.stastd1.request.size:110|ms", metrics) - assert.contains("kong.stastd1.request.status.200:1|c", metrics) - assert.contains("kong.stastd1.request.status.total:1|c", metrics) - assert.contains("kong.stastd1.response.size:%d+|ms", metrics, true) - assert.contains("kong.stastd1.upstream_latency:%d*|ms", metrics, true) - assert.contains("kong.stastd1.kong_latency:%d*|ms", metrics, true) - assert.contains("kong.stastd1.user.uniques:robert|s", metrics) - assert.contains("kong.stastd1.user.robert.request.count:1|c", metrics) - assert.contains("kong.stastd1.user.robert.request.status.total:1|c", + assert.contains("kong.statsd1.request.count:1|c", metrics) + assert.contains("kong.statsd1.latency:%d+|ms", metrics, true) + assert.contains("kong.statsd1.request.size:110|ms", metrics) + assert.contains("kong.statsd1.request.status.200:1|c", metrics) + assert.contains("kong.statsd1.request.status.total:1|c", metrics) + assert.contains("kong.statsd1.response.size:%d+|ms", metrics, true) + assert.contains("kong.statsd1.upstream_latency:%d*|ms", metrics, true) + assert.contains("kong.statsd1.kong_latency:%d*|ms", metrics, true) + assert.contains("kong.statsd1.user.uniques:robert|s", metrics) + assert.contains("kong.statsd1.user.robert.request.count:1|c", metrics) + assert.contains("kong.statsd1.user.robert.request.status.total:1|c", metrics) - assert.contains("kong.stastd1.user.robert.request.status.200:1|c", + assert.contains("kong.statsd1.user.robert.request.status.200:1|c", metrics) end) it("logs over UDP with default metrics and new prefix", function() @@ -388,19 +322,19 @@ for _, strategy in helpers.each_strategy() do assert.res_status(200, response) local ok, metrics = thread:join() assert.True(ok) - assert.contains("prefix.stastd13.request.count:1|c", metrics) - assert.contains("prefix.stastd13.latency:%d+|ms", metrics, true) - assert.contains("prefix.stastd13.request.size:%d*|ms", metrics, true) - assert.contains("prefix.stastd13.request.status.200:1|c", metrics) - assert.contains("prefix.stastd13.request.status.total:1|c", metrics) - assert.contains("prefix.stastd13.response.size:%d+|ms", metrics, true) - assert.contains("prefix.stastd13.upstream_latency:%d*|ms", metrics, true) - assert.contains("prefix.stastd13.kong_latency:%d*|ms", metrics, true) - assert.contains("prefix.stastd13.user.uniques:robert|s", metrics) - assert.contains("prefix.stastd13.user.robert.request.count:1|c", metrics) - assert.contains("prefix.stastd13.user.robert.request.status.total:1|c", + assert.contains("prefix.statsd13.request.count:1|c", metrics) + assert.contains("prefix.statsd13.latency:%d+|ms", metrics, true) + assert.contains("prefix.statsd13.request.size:%d*|ms", metrics, true) + assert.contains("prefix.statsd13.request.status.200:1|c", metrics) + assert.contains("prefix.statsd13.request.status.total:1|c", metrics) + assert.contains("prefix.statsd13.response.size:%d+|ms", metrics, true) + assert.contains("prefix.statsd13.upstream_latency:%d*|ms", metrics, true) + assert.contains("prefix.statsd13.kong_latency:%d*|ms", metrics, true) + assert.contains("prefix.statsd13.user.uniques:robert|s", metrics) + assert.contains("prefix.statsd13.user.robert.request.count:1|c", metrics) + assert.contains("prefix.statsd13.user.robert.request.status.total:1|c", metrics) - assert.contains("prefix.stastd13.user.robert.request.status.200:1|c", + assert.contains("prefix.statsd13.user.robert.request.status.200:1|c", metrics) end) it("request_count", function() @@ -416,7 +350,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.equal("kong.stastd5.request.count:1|c", res) + assert.equal("kong.statsd5.request.count:1|c", res) end) it("status_count", function() local threads = require "llthreads2.ex" @@ -448,8 +382,8 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.contains("kong.stastd3.request.status.200:1|c", res) - assert.contains("kong.stastd3.request.status.total:1|c", res) + assert.contains("kong.statsd3.request.status.200:1|c", res) + assert.contains("kong.statsd3.request.status.total:1|c", res) end) it("request_size", function() local thread = helpers.udp_server(UDP_PORT) @@ -464,7 +398,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong.stastd4.request.size:%d+|ms", res) + assert.matches("kong.statsd4.request.size:%d+|ms", res) end) it("latency", function() local thread = helpers.udp_server(UDP_PORT) @@ -479,7 +413,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong.stastd2.latency:.*|ms", res) + assert.matches("kong.statsd2.latency:.*|ms", res) end) it("response_size", function() local thread = helpers.udp_server(UDP_PORT) @@ -494,7 +428,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong.stastd6.response.size:%d+|ms", res) + assert.matches("kong.statsd6.response.size:%d+|ms", res) end) it("upstream_latency", function() local thread = helpers.udp_server(UDP_PORT) @@ -509,7 +443,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong.stastd7.upstream_latency:.*|ms", res) + assert.matches("kong.statsd7.upstream_latency:.*|ms", res) end) it("kong_latency", function() local thread = helpers.udp_server(UDP_PORT) @@ -524,7 +458,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong.stastd8.kong_latency:.*|ms", res) + assert.matches("kong.statsd8.kong_latency:.*|ms", res) end) it("unique_users", function() local thread = helpers.udp_server(UDP_PORT) @@ -539,7 +473,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong.stastd9.user.uniques:robert|s", res) + assert.matches("kong.statsd9.user.uniques:robert|s", res) end) it("status_count_per_user", function() local threads = require "llthreads2.ex" @@ -571,8 +505,8 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.contains("kong.stastd10.user.robert.request.status.200:1|c", res) - assert.contains("kong.stastd10.user.robert.request.status.total:1|c", res) + assert.contains("kong.statsd10.user.robert.request.status.200:1|c", res) + assert.contains("kong.statsd10.user.robert.request.status.total:1|c", res) end) it("request_per_user", function() local thread = helpers.udp_server(UDP_PORT) @@ -587,7 +521,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong.stastd11.user.bob.request.count:1|c", res) + assert.matches("kong.statsd11.user.bob.request.count:1|c", res) end) it("latency as gauge", function() local thread = helpers.udp_server(UDP_PORT) @@ -602,7 +536,7 @@ for _, strategy in helpers.each_strategy() do local ok, res = thread:join() assert.True(ok) - assert.matches("kong%.stastd12%.latency:%d+|g", res) + assert.matches("kong%.statsd12%.latency:%d+|g", res) end) end) end) diff --git a/spec/03-plugins/24-rate-limiting/02-policies_spec.lua b/spec/03-plugins/24-rate-limiting/02-policies_spec.lua index 1a641a0200b..e6194e84bd9 100644 --- a/spec/03-plugins/24-rate-limiting/02-policies_spec.lua +++ b/spec/03-plugins/24-rate-limiting/02-policies_spec.lua @@ -9,8 +9,8 @@ for _, strategy in helpers.each_strategy() do describe("cluster", function() local cluster_policy = policies.cluster - local route_id = uuid() local identifier = uuid() + local conf = { route_id = uuid(), service_id = uuid() } local db local dao @@ -33,7 +33,7 @@ for _, strategy in helpers.each_strategy() do local periods = timestamp.get_timestamps(current_timestamp) for period in pairs(periods) do - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period)) assert.equal(0, metric) end @@ -53,21 +53,21 @@ for _, strategy in helpers.each_strategy() do } -- First increment - assert(cluster_policy.increment(nil, limits, route_id, identifier, current_timestamp, 1)) + assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) -- First select for period in pairs(periods) do - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period)) assert.equal(1, metric) end -- Second increment - assert(cluster_policy.increment(nil, limits, route_id, identifier, current_timestamp, 1)) + assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) -- Second select for period in pairs(periods) do - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period)) assert.equal(2, metric) end @@ -77,7 +77,7 @@ for _, strategy in helpers.each_strategy() do periods = timestamp.get_timestamps(current_timestamp) -- Third increment - assert(cluster_policy.increment(nil, limits, route_id, identifier, current_timestamp, 1)) + assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) -- Third select with 1 second delay for period in pairs(periods) do @@ -86,7 +86,7 @@ for _, strategy in helpers.each_strategy() do expected_value = 1 end - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period)) assert.equal(expected_value, metric) end diff --git a/spec/03-plugins/24-rate-limiting/04-access_spec.lua b/spec/03-plugins/24-rate-limiting/04-access_spec.lua index e7445040ed8..d8b448a7bcb 100644 --- a/spec/03-plugins/24-rate-limiting/04-access_spec.lua +++ b/spec/03-plugins/24-rate-limiting/04-access_spec.lua @@ -12,6 +12,12 @@ local REDIS_DATABASE = 1 local SLEEP_TIME = 1 +local fmt = string.format + + +local proxy_client = helpers.proxy_client + + local function wait(second_offset) -- If the minute elapses in the middle of the test, then the test will -- fail. So we give it this test 30 seconds to execute, and if the second @@ -23,9 +29,6 @@ local function wait(second_offset) end -wait() -- Wait before starting - - local function flush_redis() local redis = require "resty.redis" local red = redis:new() @@ -54,7 +57,7 @@ end for _, strategy in helpers.each_strategy() do for _, policy in ipairs({"local", "cluster", "redis"}) do - describe("#flaky Plugin: rate-limiting (access) with policy: " .. policy .. " [#" .. strategy .. "]", function() + describe(fmt("#flaky Plugin: rate-limiting (access) with policy: %s [#%s]", policy, strategy), function() local bp local db local dao @@ -64,38 +67,40 @@ for _, strategy in helpers.each_strategy() do flush_redis() bp, db, dao = helpers.get_db_utils(strategy) + assert(db:truncate()) + dao:truncate_tables() + assert(dao:run_migrations()) local consumer1 = bp.consumers:insert { custom_id = "provider_123", } - assert(dao.keyauth_credentials:insert { + bp.keyauth_credentials:insert { key = "apikey122", consumer_id = consumer1.id, - }) + } local consumer2 = bp.consumers:insert { custom_id = "provider_124", } - assert(dao.keyauth_credentials:insert { + bp.keyauth_credentials:insert { key = "apikey123", consumer_id = consumer2.id, - }) + } - assert(dao.keyauth_credentials:insert { + bp.keyauth_credentials:insert { key = "apikey333", consumer_id = consumer2.id, - }) + } local route1 = bp.routes:insert { hosts = { "test1.com" }, } - bp.plugins:insert { - name = "rate-limiting", + bp.rate_limiting_plugins:insert({ route_id = route1.id, - config = { + config = { policy = policy, minute = 6, fault_tolerant = false, @@ -104,16 +109,15 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE, } - } + }) local route2 = bp.routes:insert { hosts = { "test2.com" }, } - bp.plugins:insert { - name = "rate-limiting", + bp.rate_limiting_plugins:insert({ route_id = route2.id, - config = { + config = { minute = 3, hour = 5, fault_tolerant = false, @@ -123,7 +127,7 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE, } - } + }) local route3 = bp.routes:insert { hosts = { "test3.com" }, @@ -134,10 +138,9 @@ for _, strategy in helpers.each_strategy() do route_id = route3.id, } - bp.plugins:insert { - name = "rate-limiting", + bp.rate_limiting_plugins:insert({ route_id = route3.id, - config = { + config = { minute = 6, limit_by = "credential", fault_tolerant = false, @@ -147,11 +150,10 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE, } - } + }) - bp.plugins:insert { - name = "rate-limiting", - route_id = route3.id, + bp.rate_limiting_plugins:insert({ + route_id = route3.id, consumer_id = consumer1.id, config = { minute = 8, @@ -162,7 +164,7 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE } - } + }) local route4 = bp.routes:insert { hosts = { "test4.com" }, @@ -173,9 +175,8 @@ for _, strategy in helpers.each_strategy() do route_id = route4.id, } - bp.plugins:insert { - name = "rate-limiting", - route_id = route4.id, + bp.rate_limiting_plugins:insert({ + route_id = route4.id, consumer_id = consumer1.id, config = { minute = 6, @@ -186,14 +187,13 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE, }, - } + }) local route5 = bp.routes:insert { hosts = { "test5.com" }, } - bp.plugins:insert { - name = "rate-limiting", + bp.rate_limiting_plugins:insert({ route_id = route5.id, config = { policy = policy, @@ -205,46 +205,54 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE, }, + }) + + local service = bp.services:insert() + bp.routes:insert { + hosts = { "test-service1.com" }, + service = service, + } + bp.routes:insert { + hosts = { "test-service2.com" }, + service = service, } + bp.rate_limiting_plugins:insert({ + service_id = service.id, + config = { + policy = policy, + minute = 6, + fault_tolerant = false, + redis_host = REDIS_HOST, + redis_port = REDIS_PORT, + redis_password = REDIS_PASSWORD, + redis_database = REDIS_DATABASE, + } + }) + assert(helpers.start_kong({ database = strategy, nginx_conf = "spec/fixtures/custom_nginx.template", })) end) + teardown(function() helpers.stop_kong() + dao:drop_schema() + assert(db:truncate()) + assert(dao:run_migrations()) end) - local proxy_client - local admin_client - before_each(function() wait(3) - proxy_client = helpers.proxy_client() - admin_client = helpers.admin_client() - end) - - after_each(function() - if proxy_client then - proxy_client:close() - end - - if admin_client then - admin_client:close() - end end) describe("Without authentication (IP address)", function() it("blocks if exceeding limit", function() for i = 1, 6 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "test1.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "test1.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -255,12 +263,40 @@ for _, strategy in helpers.each_strategy() do end -- Additonal request, while limit is 6/minute - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "test1.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "test1.com" }, + }) + local body = assert.res_status(429, res) + local json = cjson.decode(body) + assert.same({ message = "API rate limit exceeded" }, json) + end) + + it("counts against the same service register from different routes", function() + for i = 1, 3 do + local res = proxy_client():get("/status/200", { + headers = { Host = "test-service1.com" }, + }) + ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + + assert.res_status(200, res) + assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) + assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) + end + + for i = 4, 6 do + local res = proxy_client():get("/status/200", { + headers = { Host = "test-service2.com" }, + }) + ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + + assert.res_status(200, res) + assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) + assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) + end + + -- Additonal request, while limit is 6/minute + local res = proxy_client():get("/status/200", { + headers = { Host = "test-service1.com" }, }) local body = assert.res_status(429, res) local json = cjson.decode(body) @@ -274,12 +310,8 @@ for _, strategy in helpers.each_strategy() do } for i = 1, 3 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "test2.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "test2.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -291,12 +323,9 @@ for _, strategy in helpers.each_strategy() do assert.are.same(limits.hour - i, tonumber(res.headers["x-ratelimit-remaining-hour"])) end - local res = assert(helpers.proxy_client():send { - method = "GET", + local res = proxy_client():get("/status/200", { path = "/status/200", - headers = { - ["Host"] = "test2.com" - } + headers = { Host = "test2.com" }, }) local body = assert.res_status(429, res) local json = cjson.decode(body) @@ -309,12 +338,8 @@ for _, strategy in helpers.each_strategy() do describe("API-specific plugin", function() it("blocks if exceeding limit", function() for i = 1, 6 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200?apikey=apikey123", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/status/200?apikey=apikey123", { + headers = { Host = "test3.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -325,37 +350,25 @@ for _, strategy in helpers.each_strategy() do end -- Third query, while limit is 2/minute - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200?apikey=apikey123", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/status/200?apikey=apikey123", { + headers = { Host = "test3.com" }, }) local body = assert.res_status(429, res) local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) -- Using a different key of the same consumer works - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200?apikey=apikey333", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/status/200?apikey=apikey333", { + headers = { Host = "test3.com" }, }) assert.res_status(200, res) end) end) - describe("Plugin customized for specific consumer", function() + describe("Plugin customized for specific consumer and route", function() it("blocks if exceeding limit", function() for i = 1, 8 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200?apikey=apikey122", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/status/200?apikey=apikey122", { + headers = { Host = "test3.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -365,12 +378,8 @@ for _, strategy in helpers.each_strategy() do assert.are.same(8 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200?apikey=apikey122", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/status/200?apikey=apikey122", { + headers = { Host = "test3.com" }, }) local body = assert.res_status(429, res) local json = cjson.decode(body) @@ -378,12 +387,8 @@ for _, strategy in helpers.each_strategy() do end) it("blocks if the only rate-limiting plugin existing is per consumer and not per API", function() for i = 1, 6 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200?apikey=apikey122", - headers = { - ["Host"] = "test4.com" - } + local res = proxy_client():get("/status/200?apikey=apikey122", { + headers = { Host = "test4.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -393,12 +398,8 @@ for _, strategy in helpers.each_strategy() do assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200?apikey=apikey122", - headers = { - ["Host"] = "test4.com" - } + local res = proxy_client():get("/status/200?apikey=apikey122", { + headers = { Host = "test4.com" }, }) local body = assert.res_status(429, res) local json = cjson.decode(body) @@ -409,12 +410,8 @@ for _, strategy in helpers.each_strategy() do describe("Config with hide_client_headers", function() it("does not send rate-limit headers when hide_client_headers==true", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "test5.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "test5.com" }, }) assert.res_status(200, res) @@ -432,21 +429,20 @@ for _, strategy in helpers.each_strategy() do assert(db:truncate()) dao:truncate_tables() - local route1 = assert(db.routes:insert { + local route1 = bp.routes:insert { hosts = { "failtest1.com" }, - }) + } - bp.plugins:insert { - name = "rate-limiting", + bp.rate_limiting_plugins:insert { route_id = route1.id, config = { minute = 6, fault_tolerant = false } } - local route2 = assert(db.routes:insert { + local route2 = bp.routes:insert { hosts = { "failtest2.com" }, - }) + } - bp.plugins:insert { + bp.rate_limiting_plugins:insert { name = "rate-limiting", route_id = route2.id, config = { minute = 6, fault_tolerant = true }, @@ -465,12 +461,8 @@ for _, strategy in helpers.each_strategy() do end) it("does not work if an error occurs", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest1.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest1.com" }, }) assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) @@ -480,24 +472,16 @@ for _, strategy in helpers.each_strategy() do assert(dao.db:drop_table("ratelimiting_metrics")) -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest1.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest1.com" }, }) local body = assert.res_status(500, res) local json = cjson.decode(body) assert.same({ message = "An unexpected error occurred" }, json) end) it("keeps working if an error occurs", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest2.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest2.com" }, }) assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) @@ -507,12 +491,8 @@ for _, strategy in helpers.each_strategy() do assert(dao.db:drop_table("ratelimiting_metrics")) -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest2.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest2.com" }, }) assert.res_status(200, res) assert.falsy(res.headers["x-ratelimit-limit-minute"]) @@ -528,31 +508,30 @@ for _, strategy in helpers.each_strategy() do local service1 = bp.services:insert() - local route1 = assert(db.routes:insert { + local route1 = bp.routes:insert { hosts = { "failtest3.com" }, protocols = { "http", "https" }, service = service1 - }) + } - assert(dao.plugins:insert { - name = "rate-limiting", + bp.rate_limiting_plugins:insert { route_id = route1.id, config = { minute = 6, policy = policy, redis_host = "5.5.5.5", fault_tolerant = false }, - }) + } local service2 = bp.services:insert() - local route2 = assert(db.routes:insert { + local route2 = bp.routes:insert { hosts = { "failtest4.com" }, protocols = { "http", "https" }, service = service2 - }) + } - assert(dao.plugins:insert { + bp.rate_limiting_plugins:insert { name = "rate-limiting", route_id = route2.id, config = { minute = 6, policy = policy, redis_host = "5.5.5.5", fault_tolerant = true }, - }) + } assert(helpers.start_kong({ database = strategy, @@ -562,12 +541,8 @@ for _, strategy in helpers.each_strategy() do it("does not work if an error occurs", function() -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest3.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest3.com" }, }) local body = assert.res_status(500, res) local json = cjson.decode(body) @@ -575,12 +550,8 @@ for _, strategy in helpers.each_strategy() do end) it("keeps working if an error occurs", function() -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest4.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest4.com" }, }) assert.res_status(200, res) assert.falsy(res.headers["x-ratelimit-limit-minute"]) @@ -598,11 +569,10 @@ for _, strategy in helpers.each_strategy() do local bp = helpers.get_db_utils(strategy) route = bp.routes:insert { - hosts = { "expire1.com" }, + hosts = { "expire1.com" }, } - bp.plugins:insert { - name = "rate-limiting", + bp.rate_limiting_plugins:insert { route_id = route.id, config = { minute = 6, @@ -622,12 +592,8 @@ for _, strategy in helpers.each_strategy() do end) describe("expires a counter", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "expire1.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "expire1.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -638,12 +604,8 @@ for _, strategy in helpers.each_strategy() do ngx.sleep(61) -- Wait for counter to expire - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "expire1.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "expire1.com" } }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -654,5 +616,149 @@ for _, strategy in helpers.each_strategy() do end) end) end) + + describe(fmt("#flaky Plugin: rate-limiting (access - global for single consumer) with policy: %s [#%s]", policy, strategy), function() + local bp + local db + local dao + setup(function() + helpers.kill_all() + flush_redis() + bp, db, dao = helpers.get_db_utils(strategy) + assert(db:truncate()) + dao:truncate_tables() + assert(dao:run_migrations()) + + local consumer = bp.consumers:insert { + custom_id = "provider_125", + } + + bp.key_auth_plugins:insert() + + bp.keyauth_credentials:insert { + key = "apikey125", + consumer_id = consumer.id, + } + + -- just consumer, no no route or service + bp.rate_limiting_plugins:insert({ + consumer_id = consumer.id, + config = { + limit_by = "credential", + policy = policy, + minute = 6, + fault_tolerant = false, + redis_host = REDIS_HOST, + redis_port = REDIS_PORT, + redis_password = REDIS_PASSWORD, + redis_database = REDIS_DATABASE, + } + }) + + for i = 1, 6 do + bp.routes:insert({ hosts = { fmt("test%d.com", i) } }) + end + + assert(helpers.start_kong({ + database = strategy, + nginx_conf = "spec/fixtures/custom_nginx.template", + })) + end) + + teardown(function() + helpers.kill_all() + dao:drop_schema() + assert(db:truncate()) + assert(dao:run_migrations()) + end) + + it("blocks when the consumer exceeds their quota, no matter what service/route used", function() + for i = 1, 6 do + local res = proxy_client():get("/status/200?apikey=apikey125", { + headers = { Host = fmt("test%d.com", i) }, + }) + + ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + + assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) + assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) + end + + -- Additonal request, while limit is 6/minute + local res = proxy_client():get("/status/200?apikey=apikey125", { + headers = { Host = "test1.com" }, + }) + local body = assert.res_status(429, res) + local json = cjson.decode(body) + assert.same({ message = "API rate limit exceeded" }, json) + end) + end) + + describe(fmt("#flaky Plugin: rate-limiting (access - global) with policy: %s [#%s]", policy, strategy), function() + local bp + local db + local dao + setup(function() + helpers.kill_all() + flush_redis() + bp, db, dao = helpers.get_db_utils(strategy) + assert(db:truncate()) + dao:truncate_tables() + assert(dao:run_migrations()) + + -- global plugin (not attached to route, service or consumer) + bp.rate_limiting_plugins:insert({ + config = { + policy = policy, + minute = 6, + fault_tolerant = false, + redis_host = REDIS_HOST, + redis_port = REDIS_PORT, + redis_password = REDIS_PASSWORD, + redis_database = REDIS_DATABASE, + } + }) + + for i = 1, 6 do + bp.routes:insert({ hosts = { fmt("test%d.com", i) } }) + end + + assert(helpers.start_kong({ + database = strategy, + nginx_conf = "spec/fixtures/custom_nginx.template", + })) + end) + + teardown(function() + helpers.kill_all() + dao:drop_schema() + assert(db:truncate()) + assert(dao:run_migrations()) + end) + + it("blocks if exceeding limit", function() + for i = 1, 6 do + local res = proxy_client():get("/status/200", { + headers = { Host = fmt("test%d.com", i) }, + }) + + assert.res_status(200, res) + assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) + assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) + end + + ngx.sleep(SLEEP_TIME) + + -- Additonal request, while limit is 6/minute + local res = proxy_client():get("/status/200", { + headers = { Host = "test1.com" }, + }) + local body = assert.res_status(429, res) + local json = cjson.decode(body) + assert.same({ message = "API rate limit exceeded" }, json) + end) + end) end end + + diff --git a/spec/03-plugins/25-response-rate-limiting/02-policies_spec.lua b/spec/03-plugins/25-response-rate-limiting/02-policies_spec.lua index 60081809a56..7bb654c57a5 100644 --- a/spec/03-plugins/25-response-rate-limiting/02-policies_spec.lua +++ b/spec/03-plugins/25-response-rate-limiting/02-policies_spec.lua @@ -9,7 +9,7 @@ for _, strategy in helpers.each_strategy() do describe("cluster", function() local cluster_policy = policies.cluster - local route_id = uuid() + local conf = { route_id = uuid(), service_id = uuid() } local identifier = uuid() local db @@ -33,7 +33,7 @@ for _, strategy in helpers.each_strategy() do local periods = timestamp.get_timestamps(current_timestamp) for period in pairs(periods) do - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period, "video")) assert.equal(0, metric) end @@ -44,21 +44,21 @@ for _, strategy in helpers.each_strategy() do local periods = timestamp.get_timestamps(current_timestamp) -- First increment - assert(cluster_policy.increment(nil, route_id, identifier, current_timestamp, 1, "video")) + assert(cluster_policy.increment(conf, identifier, current_timestamp, 1, "video")) -- First select for period in pairs(periods) do - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period, "video")) assert.equal(1, metric) end -- Second increment - assert(cluster_policy.increment(nil, route_id, identifier, current_timestamp, 1, "video")) + assert(cluster_policy.increment(conf, identifier, current_timestamp, 1, "video")) -- Second select for period in pairs(periods) do - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period, "video")) assert.equal(2, metric) end @@ -68,7 +68,7 @@ for _, strategy in helpers.each_strategy() do periods = timestamp.get_timestamps(current_timestamp) -- Third increment - assert(cluster_policy.increment(nil, route_id, identifier, current_timestamp, 1, "video")) + assert(cluster_policy.increment(conf, identifier, current_timestamp, 1, "video")) -- Third select with 1 second delay for period in pairs(periods) do @@ -79,7 +79,7 @@ for _, strategy in helpers.each_strategy() do expected_value = 1 end - local metric = assert(cluster_policy.usage(nil, route_id, identifier, + local metric = assert(cluster_policy.usage(conf, identifier, current_timestamp, period, "video")) assert.equal(expected_value, metric) end diff --git a/spec/03-plugins/25-response-rate-limiting/04-access_spec.lua b/spec/03-plugins/25-response-rate-limiting/04-access_spec.lua index 5369c20aac5..827248da428 100644 --- a/spec/03-plugins/25-response-rate-limiting/04-access_spec.lua +++ b/spec/03-plugins/25-response-rate-limiting/04-access_spec.lua @@ -9,7 +9,13 @@ local REDIS_PASSWORD = "" local REDIS_DATABASE = 1 -local SLEEP_TIME = 1 +local SLEEP_TIME = 1 + + +local fmt = string.format + + +local proxy_client = helpers.proxy_client local function wait(second_offset) @@ -54,7 +60,7 @@ end for _, strategy in helpers.each_strategy() do for i, policy in ipairs({"local", "cluster", "redis"}) do - aescribe("#flaky Plugin: response-ratelimiting (access) with policy: " .. policy .. "[#" .. strategy .. "]", function() + describe(fmt("#flaky Plugin: response-ratelimiting (access) with policy: %s [#%s]", policy, strategy), function() local db local dao local bp @@ -64,34 +70,33 @@ for _, strategy in helpers.each_strategy() do flush_redis() - local consumer1 = assert(dao.consumers:insert {custom_id = "provider_123"}) - assert(dao.keyauth_credentials:insert { + local consumer1 = bp.consumers:insert {custom_id = "provider_123"} + bp.keyauth_credentials:insert { key = "apikey123", consumer_id = consumer1.id - }) + } - local consumer2 = assert(dao.consumers:insert {custom_id = "provider_124"}) - assert(dao.keyauth_credentials:insert { + local consumer2 = bp.consumers:insert {custom_id = "provider_124"} + bp.keyauth_credentials:insert { key = "apikey124", consumer_id = consumer2.id - }) + } - local consumer3 = assert(dao.consumers:insert {custom_id = "provider_125"}) - assert(dao.keyauth_credentials:insert { + local consumer3 = bp.consumers:insert {custom_id = "provider_125"} + bp.keyauth_credentials:insert { key = "apikey125", consumer_id = consumer3.id - }) + } local service1 = bp.services:insert() - local route1 = assert(db.routes:insert { + local route1 = bp.routes:insert { hosts = { "test1.com" }, protocols = { "http", "https" }, service = service1 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert({ route_id = route1.id, config = { fault_tolerant = false, @@ -101,19 +106,18 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE, limits = { video = { minute = 6 } }, - } + }, }) local service2 = bp.services:insert() - local route2 = assert(db.routes:insert { + local route2 = bp.routes:insert { hosts = { "test2.com" }, protocols = { "http", "https" }, service = service2 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert({ route_id = route2.id, config = { fault_tolerant = false, @@ -124,31 +128,29 @@ for _, strategy in helpers.each_strategy() do redis_database = REDIS_DATABASE, limits = { video = { minute = 6, hour = 10 }, image = { minute = 4 } }, - } + }, }) local service3 = bp.services:insert() - local route3 = assert(db.routes:insert { + local route3 = bp.routes:insert { hosts = { "test3.com" }, protocols = { "http", "https" }, service = service3 - }) + } - assert(dao.plugins:insert { + bp.plugins:insert { name = "key-auth", route_id = route3.id, - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert({ route_id = route3.id, config = { limits = { video = { minute = 6 } } }, }) - assert(dao.plugins:insert { - name = "response-ratelimiting", - route_id = route3.id, + bp.response_ratelimiting_plugins:insert({ + route_id = route3.id, consumer_id = consumer1.id, config = { fault_tolerant = false, @@ -158,19 +160,18 @@ for _, strategy in helpers.each_strategy() do redis_password = REDIS_PASSWORD, redis_database = REDIS_DATABASE, limits = { video = { minute = 2 } }, - } + }, }) local service4 = bp.services:insert() - local route4 = assert(db.routes:insert { + local route4 = bp.routes:insert { hosts = { "test4.com" }, protocols = { "http", "https" }, service = service4 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert({ route_id = route4.id, config = { fault_tolerant = false, @@ -185,14 +186,13 @@ for _, strategy in helpers.each_strategy() do local service7 = bp.services:insert() - local route7 = assert(db.routes:insert { + local route7 = bp.routes:insert { hosts = { "test7.com" }, protocols = { "http", "https" }, service = service7 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert({ route_id = route7.id, config = { fault_tolerant = false, @@ -216,14 +216,13 @@ for _, strategy in helpers.each_strategy() do local service8 = bp.services:insert() - local route8 = assert(db.routes:insert { + local route8 = bp.routes:insert { hosts = { "test8.com" }, protocols = { "http", "https" }, service = service8 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert({ route_id = route8.id, config = { fault_tolerant = false, @@ -239,14 +238,13 @@ for _, strategy in helpers.each_strategy() do local service9 = bp.services:insert() - local route9 = assert(db.routes:insert { + local route9 = bp.routes:insert { hosts = { "test9.com" }, protocols = { "http", "https" }, service = service9 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert({ route_id = route9.id, config = { fault_tolerant = false, @@ -260,6 +258,30 @@ for _, strategy in helpers.each_strategy() do } }) + + local service10 = bp.services:insert() + bp.routes:insert { + hosts = { "test-service1.com" }, + service = service10, + } + bp.routes:insert { + hosts = { "test-service2.com" }, + service = service10, + } + + bp.response_ratelimiting_plugins:insert({ + service_id = service10.id, + config = { + fault_tolerant = false, + policy = policy, + redis_host = REDIS_HOST, + redis_port = REDIS_PORT, + redis_password = REDIS_PASSWORD, + redis_database = REDIS_DATABASE, + limits = { video = { minute = 6 } }, + } + }) + assert(helpers.start_kong({ database = strategy, nginx_conf = "spec/fixtures/custom_nginx.template", @@ -270,34 +292,15 @@ for _, strategy in helpers.each_strategy() do helpers.stop_kong() end) - local proxy_client - local admin_client - before_each(function() wait(1) - proxy_client = helpers.proxy_client() - admin_client = helpers.admin_client() - end) - - after_each(function() - if proxy_client then - proxy_client:close() - end - - if admin_client then - admin_client:close() - end end) describe("Without authentication (IP address)", function() it("blocks if exceeding limit", function() for i = 1, 6 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1, test=5", - headers = { - ["Host"] = "test1.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1, test=5", { + headers = { Host = "test1.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -307,12 +310,40 @@ for _, strategy in helpers.each_strategy() do assert.equal(6 - i, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) end - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1", - headers = { - ["Host"] = "test1.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "test1.com" }, + }) + + local body = assert.res_status(429, res) + assert.equal([[]], body) + end) + + it("counts against the same service register from different routes", function() + for i = 1, 3 do + + local res = proxy_client():get("/response-headers?x-kong-limit=video=1, test=5", { + headers = { Host = "test-service1.com" }, + }) + ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + + assert.res_status(200, res) + assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) + assert.equal(6 - i, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) + end + + for i = 4, 6 do + local res = proxy_client():get("/response-headers?x-kong-limit=video=1, test=5", { + headers = { Host = "test-service2.com" }, + }) + ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + assert.res_status(200, res) + assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) + assert.equal(6 - i, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) + end + + -- Additonal request, while limit is 6/minute + local res = proxy_client():get("/response-headers?x-kong-limit=video=1, test=5", { + headers = { Host = "test-service1.com" }, }) local body = assert.res_status(429, res) assert.equal([[]], body) @@ -320,12 +351,9 @@ for _, strategy in helpers.each_strategy() do it("handles multiple limits", function() for i = 1, 3 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=2, image=1", - headers = { - ["Host"] = "test2.com" - } + + local res = proxy_client():get("/response-headers?x-kong-limit=video=2, image=1", { + headers = { Host = "test2.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -339,13 +367,10 @@ for _, strategy in helpers.each_strategy() do assert.equal(4 - i, tonumber(res.headers["x-ratelimit-remaining-image-minute"])) end - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=2, image=1", - headers = { - ["Host"] = "test2.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=2, image=1", { + headers = { Host = "test2.com" }, }) + local body = assert.res_status(429, res) assert.equal([[]], body) assert.equal(0, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) @@ -356,14 +381,10 @@ for _, strategy in helpers.each_strategy() do describe("With authentication", function() describe("API-specific plugin", function() - it("blocks if exceeding limit and a per consumer setting", function() + it("blocks if exceeding limit and a per consumer & route setting", function() for i = 1, 2 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?apikey=apikey123&x-kong-limit=video=1", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/response-headers?apikey=apikey123&x-kong-limit=video=1", { + headers = { Host = "test3.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -374,12 +395,8 @@ for _, strategy in helpers.each_strategy() do end -- Third query, while limit is 2/minute - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?apikey=apikey123&x-kong-limit=video=1", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/response-headers?apikey=apikey123&x-kong-limit=video=1", { + headers = { Host = "test3.com" }, }) local body = assert.res_status(429, res) assert.equal([[]], body) @@ -387,14 +404,10 @@ for _, strategy in helpers.each_strategy() do assert.equal(2, tonumber(res.headers["x-ratelimit-limit-video-minute"])) end) - it("blocks if exceeding limit and a per consumer setting", function() + it("blocks if exceeding limit and a per consumer & route setting", function() for i = 1, 6 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?apikey=apikey124&x-kong-limit=video=1", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/response-headers?apikey=apikey124&x-kong-limit=video=1", { + headers = { Host = "test3.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -404,13 +417,10 @@ for _, strategy in helpers.each_strategy() do assert.equal(6 - i, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) end - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?apikey=apikey124", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/response-headers?apikey=apikey124", { + headers = { Host = "test3.com" }, }) + assert.res_status(200, res) assert.equal(0, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) @@ -418,12 +428,8 @@ for _, strategy in helpers.each_strategy() do it("blocks if exceeding limit", function() for i = 1, 6 do - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?apikey=apikey125&x-kong-limit=video=1", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/response-headers?apikey=apikey125&x-kong-limit=video=1", { + headers = { Host = "test3.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -434,12 +440,8 @@ for _, strategy in helpers.each_strategy() do end -- Third query, while limit is 2/minute - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?apikey=apikey125&x-kong-limit=video=1", - headers = { - ["Host"] = "test3.com" - } + local res = proxy_client():get("/response-headers?apikey=apikey125&x-kong-limit=video=1", { + headers = { Host = "test3.com" }, }) local body = assert.res_status(429, res) assert.equal([[]], body) @@ -451,35 +453,23 @@ for _, strategy in helpers.each_strategy() do describe("Upstream usage headers", function() it("should append the headers with multiple limits", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/get", - headers = { - ["Host"] = "test8.com" - } + local res = proxy_client():get("/get", { + headers = { Host = "test8.com" }, }) local json = cjson.decode(assert.res_status(200, res)) assert.equal(4, tonumber(json.headers["x-ratelimit-remaining-image"])) assert.equal(6, tonumber(json.headers["x-ratelimit-remaining-video"])) -- Actually consume the limits - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=2, image=1", - headers = { - ["Host"] = "test8.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=2, image=1", { + headers = { Host = "test8.com" }, }) assert.res_status(200, res) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/get", - headers = { - ["Host"] = "test8.com" - } + local res = proxy_client():get("/get", { + headers = { Host = "test8.com" }, }) local body = cjson.decode(assert.res_status(200, res)) assert.equal(3, tonumber(body.headers["x-ratelimit-remaining-image"])) @@ -488,12 +478,8 @@ for _, strategy in helpers.each_strategy() do it("combines multiple x-kong-limit headers from upstream", function() for i = 1, 3 do - local res = assert(proxy_client:send { - method = "GET", - path = "/response-headers?x-kong-limit=video%3D2&x-kong-limit=image%3D1", - headers = { - ["Host"] = "test4.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video%3D2&x-kong-limit=image%3D1", { + headers = { Host = "test4.com" }, }) assert.res_status(200, res) @@ -505,12 +491,8 @@ for _, strategy in helpers.each_strategy() do ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit end - local res = assert(proxy_client:send { - method = "GET", - path = "/response-headers?x-kong-limit=video%3D2&x-kong-limit=image%3D1", - headers = { - ["Host"] = "test4.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video%3D2&x-kong-limit=image%3D1", { + headers = { Host = "test4.com" }, }) local body = assert.res_status(429, res) @@ -521,23 +503,15 @@ for _, strategy in helpers.each_strategy() do end) it("should block on first violation", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=2, image=4", - headers = { - ["Host"] = "test7.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=2, image=4", { + headers = { Host = "test7.com" }, }) assert.res_status(200, res) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=2", - headers = { - ["Host"] = "test7.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=2", { + headers = { Host = "test7.com" }, }) local body = assert.res_status(429, res) local json = cjson.decode(body) @@ -546,12 +520,8 @@ for _, strategy in helpers.each_strategy() do describe("Config with hide_client_headers", function() it("does not send rate-limit headers when hide_client_headers==true", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "test9.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "test9.com" }, }) assert.res_status(200, res) @@ -572,8 +542,7 @@ for _, strategy in helpers.each_strategy() do hosts = { "failtest1.com" }, } - bp.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert { route_id = route1.id, config = { fault_tolerant = false, @@ -589,8 +558,7 @@ for _, strategy in helpers.each_strategy() do hosts = { "failtest2.com" }, } - bp.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert { route_id = route2.id, config = { fault_tolerant = true, @@ -615,12 +583,8 @@ for _, strategy in helpers.each_strategy() do end) it("does not work if an error occurs", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1", - headers = { - ["Host"] = "failtest1.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "failtest1.com" }, }) assert.res_status(200, res) assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) @@ -630,12 +594,8 @@ for _, strategy in helpers.each_strategy() do assert(dao.db:drop_table("response_ratelimiting_metrics")) -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1", - headers = { - ["Host"] = "failtest1.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "failtest1.com" }, }) local body = assert.res_status(500, res) local json = cjson.decode(body) @@ -643,12 +603,8 @@ for _, strategy in helpers.each_strategy() do end) it("keeps working if an error occurs", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1", - headers = { - ["Host"] = "failtest2.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "failtest2.com" }, }) assert.res_status(200, res) assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) @@ -658,12 +614,8 @@ for _, strategy in helpers.each_strategy() do assert(dao.db:drop_table("response_ratelimiting_metrics")) -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1", - headers = { - ["Host"] = "failtest2.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "failtest2.com" }, }) assert.res_status(200, res) assert.is_nil(res.headers["x-ratelimit-limit-video-minute"]) @@ -681,14 +633,13 @@ for _, strategy in helpers.each_strategy() do local service1 = bp.services:insert() - local route1 = assert(db.routes:insert { + local route1 = bp.routes:insert { hosts = { "failtest3.com" }, protocols = { "http", "https" }, service = service1 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert { route_id = route1.id, config = { fault_tolerant = false, @@ -696,18 +647,17 @@ for _, strategy in helpers.each_strategy() do redis_host = "5.5.5.5", limits = { video = { minute = 6 } }, } - }) + } local service2 = bp.services:insert() - local route2 = assert(db.routes:insert { + local route2 = bp.routes:insert { hosts = { "failtest4.com" }, protocols = { "http", "https" }, service = service2 - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert { route_id = route2.id, config = { fault_tolerant = true, @@ -715,7 +665,7 @@ for _, strategy in helpers.each_strategy() do redis_host = "5.5.5.5", limits = { video = { minute = 6 } }, } - }) + } assert(helpers.start_kong({ database = strategy, @@ -725,12 +675,8 @@ for _, strategy in helpers.each_strategy() do it("does not work if an error occurs", function() -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest3.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest3.com" }, }) local body = assert.res_status(500, res) local json = cjson.decode(body) @@ -738,12 +684,8 @@ for _, strategy in helpers.each_strategy() do end) it("keeps working if an error occurs", function() -- Make another request - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/status/200", - headers = { - ["Host"] = "failtest4.com" - } + local res = proxy_client():get("/status/200", { + headers = { Host = "failtest4.com" }, }) assert.res_status(200, res) assert.falsy(res.headers["x-ratelimit-limit-video-minute"]) @@ -760,14 +702,13 @@ for _, strategy in helpers.each_strategy() do local service = bp.services:insert() - local route = assert(db.routes:insert { + local route = bp.routes:insert { hosts = { "expire1.com" }, protocols = { "http", "https" }, service = service - }) + } - assert(dao.plugins:insert { - name = "response-ratelimiting", + bp.response_ratelimiting_plugins:insert { route_id = route.id, config = { policy = policy, @@ -777,7 +718,7 @@ for _, strategy in helpers.each_strategy() do fault_tolerant = false, limits = { video = { minute = 6 } }, } - }) + } assert(helpers.start_kong({ database = strategy, @@ -786,12 +727,8 @@ for _, strategy in helpers.each_strategy() do end) it("expires a counter", function() - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1", - headers = { - ["Host"] = "expire1.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "expire1.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -802,12 +739,8 @@ for _, strategy in helpers.each_strategy() do ngx.sleep(61) -- Wait for counter to expire - local res = assert(helpers.proxy_client():send { - method = "GET", - path = "/response-headers?x-kong-limit=video=1", - headers = { - ["Host"] = "expire1.com" - } + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "expire1.com" }, }) ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit @@ -818,5 +751,150 @@ for _, strategy in helpers.each_strategy() do end) end) end) + + describe(fmt("#flaky Plugin: response-rate-limiting (access - global for single consumer) with policy: %s [#%s]", policy, strategy), function() + local bp + local db + local dao + setup(function() + helpers.kill_all() + flush_redis() + bp, db, dao = helpers.get_db_utils(strategy) + assert(db:truncate()) + dao:truncate_tables() + assert(dao:run_migrations()) + + local consumer = bp.consumers:insert { + custom_id = "provider_125", + } + + bp.key_auth_plugins:insert() + + bp.keyauth_credentials:insert { + key = "apikey125", + consumer_id = consumer.id, + } + + -- just consumer, no no route or service + bp.response_ratelimiting_plugins:insert({ + consumer_id = consumer.id, + config = { + fault_tolerant = false, + policy = policy, + redis_host = REDIS_HOST, + redis_port = REDIS_PORT, + redis_password = REDIS_PASSWORD, + redis_database = REDIS_DATABASE, + limits = { video = { minute = 6 } }, + } + }) + + for i = 1, 6 do + bp.routes:insert({ hosts = { fmt("test%d.com", i) } }) + end + + assert(helpers.start_kong({ + database = strategy, + nginx_conf = "spec/fixtures/custom_nginx.template", + })) + end) + + teardown(function() + helpers.kill_all() + dao:drop_schema() + assert(db:truncate()) + assert(dao:run_migrations()) + end) + + it("blocks when the consumer exceeds their quota, no matter what service/route used", function() + for i = 1, 6 do + local res = proxy_client():get("/response-headers?apikey=apikey125&x-kong-limit=video=1", { + headers = { Host = fmt("test%d.com", i) }, + }) + + ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + + assert.res_status(200, res) + assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) + assert.equal(6 - i, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) + end + + -- 7th query, while limit is 6/minute + local res = proxy_client():get("/response-headers?apikey=apikey125&x-kong-limit=video=1", { + headers = { Host = "test1.com" }, + }) + local body = assert.res_status(429, res) + assert.equal([[]], body) + assert.equal(0, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) + assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) + end) + end) + + describe(fmt("#flaky Plugin: rate-limiting (access - global) with policy: %s [#%s]", policy, strategy), function() + local bp + local db + local dao + setup(function() + helpers.kill_all() + flush_redis() + bp, db, dao = helpers.get_db_utils(strategy) + assert(db:truncate()) + dao:truncate_tables() + assert(dao:run_migrations()) + + -- global plugin (not attached to route, service or consumer) + bp.response_ratelimiting_plugins:insert({ + config = { + fault_tolerant = false, + policy = policy, + redis_host = REDIS_HOST, + redis_port = REDIS_PORT, + redis_password = REDIS_PASSWORD, + redis_database = REDIS_DATABASE, + limits = { video = { minute = 6 } }, + } + }) + + for i = 1, 6 do + bp.routes:insert({ hosts = { fmt("test%d.com", i) } }) + end + + assert(helpers.start_kong({ + database = strategy, + nginx_conf = "spec/fixtures/custom_nginx.template", + })) + end) + + teardown(function() + helpers.kill_all() + dao:drop_schema() + assert(db:truncate()) + assert(dao:run_migrations()) + end) + + it("blocks if exceeding limit", function() + for i = 1, 6 do + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = fmt("test%d.com", i) }, + }) + + ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + + assert.res_status(200, res) + assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) + assert.equal(6 - i, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) + end + + -- 7th query, while limit is 6/minute + local res = proxy_client():get("/response-headers?x-kong-limit=video=1", { + headers = { Host = "test1.com" }, + }) + local body = assert.res_status(429, res) + assert.equal([[]], body) + assert.equal(0, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) + assert.equal(6, tonumber(res.headers["x-ratelimit-limit-video-minute"])) + end) + end) + end end diff --git a/spec/03-plugins/26-oauth2/01-schema_spec.lua b/spec/03-plugins/26-oauth2/01-schema_spec.lua index d24cb578a3b..e1b43f33590 100644 --- a/spec/03-plugins/26-oauth2/01-schema_spec.lua +++ b/spec/03-plugins/26-oauth2/01-schema_spec.lua @@ -1,70 +1,195 @@ +local helpers = require "spec.helpers" local validate_entity = require("kong.dao.schemas_validation").validate_entity +local oauth2_daos = require "kong.plugins.oauth2.daos" +local utils = require "kong.tools.utils" + local oauth2_schema = require "kong.plugins.oauth2.schema" +local oauth2_authorization_codes_schema = oauth2_daos.oauth2_authorization_codes +local oauth2_tokens_schema = oauth2_daos.oauth2_tokens -pending("Plugin: oauth2 (schema)", function() - it("does not require `scopes` when `mandatory_scope` is false", function() - local ok, errors = validate_entity({enable_authorization_code = true, mandatory_scope = false}, oauth2_schema) - assert.True(ok) - assert.is_nil(errors) - end) - it("valid when both `scopes` when `mandatory_scope` are given", function() - local ok, errors = validate_entity({enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}}, oauth2_schema) - assert.True(ok) - assert.is_nil(errors) - end) - it("autogenerates `provision_key` when not given", function() - local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}} - local ok, errors = validate_entity(t, oauth2_schema) - assert.True(ok) - assert.is_nil(errors) - assert.truthy(t.provision_key) - assert.equal(32, t.provision_key:len()) - end) - it("does not autogenerate `provision_key` when it is given", function() - local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}, provision_key = "hello"} - local ok, errors = validate_entity(t, oauth2_schema) - assert.True(ok) - assert.is_nil(errors) - assert.truthy(t.provision_key) - assert.equal("hello", t.provision_key) - end) - it("sets default `auth_header_name` when not given", function() - local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}} - local ok, errors = validate_entity(t, oauth2_schema) - assert.True(ok) - assert.is_nil(errors) - assert.truthy(t.provision_key) - assert.equal(32, t.provision_key:len()) - assert.equal("authorization", t.auth_header_name) - end) - it("does not set default value for `auth_header_name` when it is given", function() - local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}, provision_key = "hello", - auth_header_name="custom_header_name"} - local ok, errors = validate_entity(t, oauth2_schema) - assert.True(ok) - assert.is_nil(errors) - assert.truthy(t.provision_key) - assert.equal("hello", t.provision_key) - assert.equal("custom_header_name", t.auth_header_name) - end) - it("sets refresh_token_ttl to default value if not set", function() - local t = {enable_authorization_code = true, mandatory_scope = false} - local ok, errors = validate_entity(t, oauth2_schema) - assert.True(ok) - assert.is_nil(errors) - assert.equal(1209600, t.refresh_token_ttl) - end) - describe("errors", function() - it("requires at least one flow", function() - local ok, _, err = validate_entity({}, oauth2_schema) - assert.False(ok) - assert.equal("You need to enable at least one OAuth flow", tostring(err)) + +local fmt = string.format + + +for _, strategy in helpers.each_strategy() do + describe(fmt("Plugin: oauth2 [#%s] (schema)", strategy), function() + local bp, db, dao = helpers.get_db_utils(strategy) + + it("does not require `scopes` when `mandatory_scope` is false", function() + local ok, errors = validate_entity({enable_authorization_code = true, mandatory_scope = false}, oauth2_schema) + assert.True(ok) + assert.is_nil(errors) + end) + it("valid when both `scopes` when `mandatory_scope` are given", function() + local ok, errors = validate_entity({enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}}, oauth2_schema) + assert.True(ok) + assert.is_nil(errors) + end) + it("autogenerates `provision_key` when not given", function() + local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}} + local ok, errors = validate_entity(t, oauth2_schema) + assert.True(ok) + assert.is_nil(errors) + assert.truthy(t.provision_key) + assert.equal(32, t.provision_key:len()) + end) + it("does not autogenerate `provision_key` when it is given", function() + local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}, provision_key = "hello"} + local ok, errors = validate_entity(t, oauth2_schema) + assert.True(ok) + assert.is_nil(errors) + assert.truthy(t.provision_key) + assert.equal("hello", t.provision_key) end) - it("requires `scopes` when `mandatory_scope` is true", function() - local ok, errors = validate_entity({enable_authorization_code = true, mandatory_scope = true}, oauth2_schema) - assert.False(ok) - assert.equal("To set a mandatory scope you also need to create available scopes", errors.mandatory_scope) + it("sets default `auth_header_name` when not given", function() + local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}} + local ok, errors = validate_entity(t, oauth2_schema) + assert.True(ok) + assert.is_nil(errors) + assert.truthy(t.provision_key) + assert.equal(32, t.provision_key:len()) + assert.equal("authorization", t.auth_header_name) + end) + it("does not set default value for `auth_header_name` when it is given", function() + local t = {enable_authorization_code = true, mandatory_scope = true, scopes = {"email", "info"}, provision_key = "hello", + auth_header_name="custom_header_name"} + local ok, errors = validate_entity(t, oauth2_schema) + assert.True(ok) + assert.is_nil(errors) + assert.truthy(t.provision_key) + assert.equal("hello", t.provision_key) + assert.equal("custom_header_name", t.auth_header_name) + end) + it("sets refresh_token_ttl to default value if not set", function() + local t = {enable_authorization_code = true, mandatory_scope = false} + local ok, errors = validate_entity(t, oauth2_schema) + assert.True(ok) + assert.is_nil(errors) + assert.equal(1209600, t.refresh_token_ttl) + end) + + describe("errors", function() + it("requires at least one flow", function() + local ok, _, err = validate_entity({}, oauth2_schema) + assert.False(ok) + assert.equal("You need to enable at least one OAuth flow", tostring(err)) + end) + it("requires `scopes` when `mandatory_scope` is true", function() + local ok, errors = validate_entity({enable_authorization_code = true, mandatory_scope = true}, oauth2_schema) + assert.False(ok) + assert.equal("To set a mandatory scope you also need to create available scopes", errors.mandatory_scope) + end) + it("errors when given an invalid service_id on oauth tokens", function() + local service = bp.services:insert() + local u = utils.uuid() + + local ok, err, err_t = validate_entity({ + credential_id = "foo", expires_in = 1, + service_id = "bar", + }, oauth2_tokens_schema, { dao = dao }) + assert.False(ok) + assert.is_nil(err) + assert.equals(err_t.tbl.fields.id, "expected a valid UUID") + + local ok, err, err_t = validate_entity({ + credential_id = "foo", expires_in = 1, + service_id = u, + }, oauth2_tokens_schema, { dao = dao }) + assert.False(ok) + assert.is_nil(err) + assert.equals(err_t.message, fmt("no such Service (id=%s)", u)) + + local ok, err, err_t = validate_entity({ + credential_id = "foo", expires_in = 1, + service_id = service.id, + }, oauth2_tokens_schema, { dao = dao }) + + assert.True(ok) + assert.is_nil(err) + assert.is_nil(err_t) + end) + + it("errors when given an invalid service_id on oauth authorization codes", function() + local service = bp.services:insert() + local u = utils.uuid() + + local ok, err, err_t = validate_entity({ + credential_id = "foo", + service_id = "bar", + }, oauth2_authorization_codes_schema, { dao = dao }) + assert.False(ok) + assert.is_nil(err) + assert.equals(err_t.tbl.fields.id, "expected a valid UUID") + + local ok, err, err_t = validate_entity({ + credential_id = "foo", + service_id = u, + }, oauth2_authorization_codes_schema, { dao = dao }) + assert.False(ok) + assert.is_nil(err) + assert.equals(err_t.message, fmt("no such Service (id=%s)", u)) + + local ok, err, err_t = validate_entity({ + credential_id = "foo", + service_id = service.id, + }, oauth2_authorization_codes_schema, { dao = dao }) + + assert.True(ok) + assert.is_nil(err) + assert.is_nil(err_t) + end) + end) + + describe("when deleting a service", function() + it("deletes associated oauth2 entities", function() + local service = bp.services:insert() + local consumer = bp.consumers:insert() + local credential = bp.oauth2_credentials:insert({ + redirect_uri = "http://example.com", + consumer_id = consumer.id, + }) + + local ok, err, err_t + + local token = bp.oauth2_tokens:insert({ + credential_id = credential.id, + service_id = service.id, + }) + local code = bp.oauth2_authorization_codes:insert({ + credential_id = credential.id, + service_id = service.id, + }) + + token, err = dao.oauth2_tokens:find(token) + assert.falsy(err) + assert.truthy(token) + + code, err = dao.oauth2_authorization_codes:find(code) + assert.falsy(err) + assert.truthy(code) + + + ok, err, err_t = db.services:delete({ id = service.id }) + assert.truthy(ok) + assert.is_nil(err_t) + assert.is_nil(err) + + -- no more service + service, err = db.services:select({ id = service.id }) + assert.falsy(err) + assert.falsy(service) + + -- no more token + token, err = dao.oauth2_tokens:find({ id = token.id }) + assert.falsy(err) + assert.falsy(token) + + -- no more code + local code, err = dao.oauth2_authorization_codes:find({ id = code.id }) + assert.falsy(err) + assert.falsy(code) + end) end) end) -end) +end diff --git a/spec/03-plugins/26-oauth2/02-api_spec.lua b/spec/03-plugins/26-oauth2/02-api_spec.lua index dce6d000e23..01a8c482570 100644 --- a/spec/03-plugins/26-oauth2/02-api_spec.lua +++ b/spec/03-plugins/26-oauth2/02-api_spec.lua @@ -3,21 +3,20 @@ local helpers = require "spec.helpers" for _, strategy in helpers.each_strategy() do - pending("Plugin: oauth (API) [#" .. strategy .. "]", function() + describe("Plugin: oauth (API) [#" .. strategy .. "]", function() local consumer - local route + local service local admin_client local db local dao local bp setup(function() - db, dao, bp = helpers.get_db_utils(strategy) + bp, db, dao = helpers.get_db_utils(strategy) assert(db:truncate()) dao:truncate_tables() - - helpers.run_migrations(dao) + assert(dao:run_migrations()) helpers.prepare_prefix() @@ -36,17 +35,9 @@ for _, strategy in helpers.each_strategy() do describe("/consumers/:consumer/oauth2/", function() setup(function() - route = bp.routes:insert({ - hosts = { "oauth2_token.com" }, - }) - - consumer = assert(dao.consumers:insert { - username = "bob" - }) - - assert(dao.consumers:insert { - username = "sally" - }) + service = bp.services:insert({ host = "oauth2_token.com" }) + consumer = bp.consumers:insert({ username = "bob" }) + bp.consumers:insert({ username = "sally" }) end) after_each(function() @@ -260,6 +251,11 @@ for _, strategy in helpers.each_strategy() do local credential before_each(function() dao:truncate_table("oauth2_credentials") + dao:truncate_table("consumers") + db:truncate("services") + + service = bp.services:insert({ host = "oauth2_token.com" }) + consumer = bp.consumers:insert({ username = "bob" }) credential = assert(dao.oauth2_credentials:insert { name = "test app", redirect_uri = helpers.mock_upstream_ssl_url, @@ -418,7 +414,7 @@ for _, strategy in helpers.each_strategy() do path = "/oauth2_tokens", body = { credential_id = oauth2_credential.id, - route_id = route.id, + service_id = service.id, expires_in = 10 }, headers = { @@ -429,7 +425,7 @@ for _, strategy in helpers.each_strategy() do assert.equal(oauth2_credential.id, body.credential_id) assert.equal(10, body.expires_in) assert.truthy(body.access_token) - assert.truthy(body.route_id) + assert.truthy(body.service_id) assert.falsy(body.refresh_token) assert.equal("bearer", body.token_type) end) @@ -457,7 +453,7 @@ for _, strategy in helpers.each_strategy() do path = "/oauth2_tokens", body = { credential_id = oauth2_credential.id, - route_id = route.id, + service_id = service.id, expires_in = 10 }, headers = { @@ -493,7 +489,7 @@ for _, strategy in helpers.each_strategy() do for _ = 1, 3 do assert(dao.oauth2_tokens:insert { credential_id = oauth2_credential.id, - route_id = route.id, + service_id = service.id, expires_in = 10 }) end @@ -520,7 +516,7 @@ for _, strategy in helpers.each_strategy() do dao:truncate_table("oauth2_tokens") token = assert(dao.oauth2_tokens:insert { credential_id = oauth2_credential.id, - route_id = route.id, + service_id = service.id, expires_in = 10 }) end) diff --git a/spec/03-plugins/26-oauth2/03-access_spec.lua b/spec/03-plugins/26-oauth2/03-access_spec.lua index 6cc942de949..5fe88261b2c 100644 --- a/spec/03-plugins/26-oauth2/03-access_spec.lua +++ b/spec/03-plugins/26-oauth2/03-access_spec.lua @@ -3,6 +3,9 @@ local helpers = require "spec.helpers" local utils = require "kong.tools.utils" +local fmt = string.format + + local function provision_code(host, extra_headers, client_id) local request_client = helpers.proxy_ssl_client() local res = assert(request_client:send { @@ -59,7 +62,7 @@ end for _, strategy in helpers.each_strategy() do - pending("Plugin: oauth2 (access) [#" .. strategy .. "]", function() + describe("Plugin: oauth2 (access) [#" .. strategy .. "]", function() local proxy_ssl_client local proxy_client local client1 @@ -68,12 +71,7 @@ for _, strategy in helpers.each_strategy() do local bp setup(function() - db, dao, bp = helpers.get_db_utils(strategy) - - assert(db:truncate()) - dao:truncate_tables() - - helpers.run_migrations(dao) + bp, db, dao = helpers.get_db_utils(strategy) local consumer = assert(dao.consumers:insert { username = "bob" @@ -115,6 +113,14 @@ for _, strategy in helpers.each_strategy() do consumer_id = consumer.id }) + assert(dao.oauth2_credentials:insert { + client_id = "clientid1011", + client_secret = "secret1011", + redirect_uri = { "http://google.com/kong", }, + name = "testapp31", + consumer_id = consumer.id + }) + local service1 = bp.services:insert() local service2 = bp.services:insert() local service2bis = bp.services:insert() @@ -126,6 +132,8 @@ for _, strategy in helpers.each_strategy() do local service8 = bp.services:insert() local service9 = bp.services:insert() local service10 = bp.services:insert() + local service11 = bp.services:insert() + local service12 = bp.services:insert() local route1 = assert(db.routes:insert({ hosts = { "oauth2.com" }, @@ -193,6 +201,18 @@ for _, strategy in helpers.each_strategy() do service = service10, })) + local route11 = assert(db.routes:insert({ + hosts = { "oauth2_11.com" }, + protocols = { "http", "https" }, + service = service11, + })) + + local route12 = assert(db.routes:insert({ + hosts = { "oauth2_12.com" }, + protocols = { "http", "https" }, + service = service12, + })) + bp.oauth2_plugins:insert({ route_id = route1.id, config = { scopes = { "email", "profile", "user.email" } }, @@ -270,6 +290,26 @@ for _, strategy in helpers.each_strategy() do }, }) + bp.oauth2_plugins:insert({ + route_id = route11.id, + config = { + scopes = { "email", "profile", "user.email" }, + global_credentials = true, + token_expiration = 7, + auth_header_name = "custom_header_name", + }, + }) + + bp.oauth2_plugins:insert({ + route_id = route12.id, + config = { + scopes = { "email", "profile", "user.email" }, + global_credentials = true, + auth_header_name = "custom_header_name", + hide_credentials = true, + }, + }) + assert(helpers.start_kong({ database = strategy, trusted_ips = "127.0.0.1", @@ -747,7 +787,6 @@ for _, strategy in helpers.each_strategy() do local body = cjson.decode(assert.res_status(200, res)) assert.is_table(ngx.re.match(body.redirect_uri, "^http://google\\.com/kong\\#access_token=[\\w]{32,32}&expires_in=[\\d]+&state=wot&token_type=bearer$")) end) - it("returns success and the token should have the right expiration", function() local res = assert(proxy_ssl_client:send { method = "POST", @@ -1842,6 +1881,35 @@ for _, strategy in helpers.each_strategy() do }) assert.res_status(500, res) end) + it("returns success and the token should have the right expiration when a custom header is passed", function() + local res = assert(proxy_ssl_client:send { + method = "POST", + path = "/oauth2/authorize", + body = { + provision_key = "provision123", + authenticated_userid = "id123", + client_id = "clientid1011", + scope = "email", + response_type = "token" + }, + headers = { + ["Host"] = "oauth2_11.com", + ["Content-Type"] = "application/json" + } + }) + local body = cjson.decode(assert.res_status(200, res)) + assert.is_table(ngx.re.match(body.redirect_uri, "^http://google\\.com/kong\\#access_token=[\\w]{32,32}&expires_in=[\\d]+&token_type=bearer$")) + + local iterator, err = ngx.re.gmatch(body.redirect_uri, "^http://google\\.com/kong\\#access_token=([\\w]{32,32})&expires_in=[\\d]+&token_type=bearer$") + assert.is_nil(err) + local m, err = iterator() + assert.is_nil(err) + local data = dao.oauth2_tokens:find_all {access_token = m[1]} + assert.are.equal(1, #data) + assert.are.equal(m[1], data[1].access_token) + assert.are.equal(7, data[1].expires_in) + assert.falsy(data[1].refresh_token) + end) describe("Global Credentials", function() it("does not access two different APIs that are not sharing global credentials", function() local token = provision_token("oauth2_8.com") @@ -2234,7 +2302,7 @@ for _, strategy in helpers.each_strategy() do end) - pending("Plugin: oauth2 (access) [#" .. strategy .. "]", function() + describe("Plugin: oauth2 (access) [#" .. strategy .. "]", function() local proxy_client local user1 local user2 @@ -2244,12 +2312,7 @@ for _, strategy in helpers.each_strategy() do local bp setup(function() - db, dao, bp = helpers.get_db_utils(strategy) - - assert(db:truncate()) - dao:truncate_tables() - - helpers.run_migrations(dao) + bp, db, dao = helpers.get_db_utils(strategy) local service1 = bp.services:insert({ path = "/request" @@ -2466,22 +2529,14 @@ for _, strategy in helpers.each_strategy() do end) end) end) -end - -for _, strategy in helpers.each_strategy() do - describe("Plugin: oauth2 (ttl) [#" .. strategy .. "]", function() + describe("Plugin: oauth2 (ttl) with #"..strategy, function() local db local dao local bp setup(function() - db, dao, bp = helpers.get_db_utils(strategy) - - assert(db:truncate()) - dao:truncate_tables() - - helpers.run_migrations(dao) + bp, db, dao = helpers.get_db_utils(strategy) local route11 = assert(db.routes:insert({ hosts = { "oauth2_11.com" }, @@ -2489,10 +2544,13 @@ for _, strategy in helpers.each_strategy() do service = bp.services:insert(), })) - bp.oauth2_plugins:insert({ + assert(dao.plugins:insert { + name = "oauth2", route_id = route11.id, config = { + enable_authorization_code = true, mandatory_scope = false, + provision_key = "provision123", anonymous = "", global_credentials = false, refresh_token_ttl = 2 @@ -2505,10 +2563,13 @@ for _, strategy in helpers.each_strategy() do service = bp.services:insert(), })) - bp.oauth2_plugins:insert({ + assert(dao.plugins:insert { + name = "oauth2", route_id = route12.id, config = { + enable_authorization_code = true, mandatory_scope = false, + provision_key = "provision123", anonymous = "", global_credentials = false, refresh_token_ttl = 0 @@ -2518,7 +2579,6 @@ for _, strategy in helpers.each_strategy() do local consumer = assert(dao.consumers:insert { username = "bob" }) - assert(dao.oauth2_credentials:insert { client_id = "clientid123", client_secret = "secret123", @@ -2526,8 +2586,11 @@ for _, strategy in helpers.each_strategy() do name = "testapp", consumer_id = consumer.id }) - - assert(helpers.start_kong()) + assert(helpers.start_kong({ + database = strategy, + trusted_ips = "127.0.0.1", + nginx_conf = "spec/fixtures/custom_nginx.template", + })) end) teardown(function() @@ -2536,8 +2599,8 @@ for _, strategy in helpers.each_strategy() do local function assert_ttls_records_for_token(uuid, count) local DB = require "kong.dao.db.postgres" - local _db = DB.new(strategy) - local query = string.format("SELECT COUNT(*) FROM ttls where table_name='oauth2_tokens' AND primary_uuid_value = '%s'", tostring(uuid)) + local _db = DB.new(helpers.test_conf, strategy) + local query = fmt("SELECT COUNT(*) FROM ttls where table_name='oauth2_tokens' AND primary_uuid_value = '%s'", tostring(uuid)) local result, error = _db:query(query) assert.falsy(error) assert.truthy(result[1].count == count) @@ -2546,7 +2609,7 @@ for _, strategy in helpers.each_strategy() do describe("refresh token", function() it("is deleted after defined TTL", function() local token = provision_token("oauth2_11.com") - local token_entity = helpers.dao.oauth2_tokens:find_all { access_token = token.access_token } + local token_entity = dao.oauth2_tokens:find_all { access_token = token.access_token } assert.equal(1, #token_entity) if strategy == "postgres" then @@ -2555,13 +2618,13 @@ for _, strategy in helpers.each_strategy() do ngx.sleep(3) - token_entity = helpers.dao.oauth2_tokens:find_all { access_token = token.access_token } + token_entity = dao.oauth2_tokens:find_all { access_token = token.access_token } assert.equal(0, #token_entity) end) it("is not deleted when when TTL is 0 == never", function() local token = provision_token("oauth2_12.com") - local token_entity = helpers.dao.oauth2_tokens:find_all { access_token = token.access_token } + local token_entity = dao.oauth2_tokens:find_all { access_token = token.access_token } assert.equal(1, #token_entity) if strategy == "postgres" then @@ -2570,11 +2633,9 @@ for _, strategy in helpers.each_strategy() do ngx.sleep(3) - token_entity = helpers.dao.oauth2_tokens:find_all { access_token = token.access_token } + token_entity = dao.oauth2_tokens:find_all { access_token = token.access_token } assert.equal(1, #token_entity) end) end) - end) - end diff --git a/spec/03-plugins/26-oauth2/04-invalidations_spec.lua b/spec/03-plugins/26-oauth2/04-invalidations_spec.lua index 115c8e757af..ff95517591a 100644 --- a/spec/03-plugins/26-oauth2/04-invalidations_spec.lua +++ b/spec/03-plugins/26-oauth2/04-invalidations_spec.lua @@ -3,7 +3,7 @@ local helpers = require "spec.helpers" for _, strategy in helpers.each_strategy() do - pending("Plugin: oauth2 (invalidations) [#" .. strategy .. "]", function() + describe("Plugin: oauth2 (invalidations) [#" .. strategy .. "]", function() local admin_client local proxy_ssl_client local db @@ -11,12 +11,11 @@ for _, strategy in helpers.each_strategy() do local bp setup(function() - db, dao, bp = helpers.get_db_utils(strategy) + bp, db, dao = helpers.get_db_utils(strategy) assert(db:truncate()) dao:truncate_tables() - - helpers.run_migrations(dao) + assert(dao:run_migrations()) end) before_each(function() diff --git a/spec/fixtures/blueprints.lua b/spec/fixtures/blueprints.lua index 8a73d27110c..726bd791986 100644 --- a/spec/fixtures/blueprints.lua +++ b/spec/fixtures/blueprints.lua @@ -5,16 +5,6 @@ local deep_merge = utils.deep_merge local fmt = string.format -local function shuffle(tbl) - local size = #tbl - for i = size, 1, -1 do - local rand = math.random(size) - tbl[i], tbl[rand] = tbl[rand], tbl[i] - end - return tbl -end - - local Blueprint = {} Blueprint.__index = Blueprint @@ -92,13 +82,10 @@ function _M.new(dao, db) local upstream_name_seq = new_sequence("upstream-%d") res.upstreams = new_blueprint(dao.upstreams, function(overrides) local slots = overrides.slots or 100 - local orderlist = {} - for i = 1, slots do orderlist[i] = i end return { name = upstream_name_seq:next(), slots = slots, - orderlist = shuffle(orderlist), } end) @@ -289,7 +276,7 @@ function _M.new(dao, db) } end) - res.response_rate_limiting_plugins = new_blueprint(dao.plugins, function() + res.response_ratelimiting_plugins = new_blueprint(dao.plugins, function() return { name = "response-ratelimiting", config = {}, @@ -303,6 +290,13 @@ function _M.new(dao, db) } end) + res.statsd_plugins = new_blueprint(dao.plugins, function() + return { + name = "statsd", + config = {}, + } + end) + return res end