From 353ab2e3728e38db07cc3c458c237d3e13ae89ac Mon Sep 17 00:00:00 2001 From: Ali-Akber Saifee Date: Tue, 24 Jun 2014 15:26:08 +0800 Subject: [PATCH] implementation for shared limits (#8) --- doc/source/index.rst | 68 ++++++++++++++++++++++++++++++++ flask_limiter/extension.py | 56 +++++++++++++++++++-------- tests/test_flask_ext.py | 79 +++++++++++++++++++++++++++++++++----- 3 files changed, 178 insertions(+), 25 deletions(-) diff --git a/doc/source/index.rst b/doc/source/index.rst index d9326d5c..992e0271 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -142,6 +142,74 @@ instance are .. note:: The callable is called from within a :ref:`flask request context `. +.. _ratelimit-decorator-shared-limit: + +:meth:`Limiter.shared_limit` + For scenarios where a rate limit should be shared by multiple routes + (For example when you want to protect routes using the same resource + with an umbrella rate limit). + + Named shared limit + + .. code-block:: python + + mysql_limit = limiter.limit("100/hour", scope="mysql") + + @app.route("..") + @mysql_limit + def r1(): + ... + + @app.route("..") + @mysql_limit + def r2(): + ... + + Anonymous shared limit: when no scope is provided a unique id will be used. + + .. code-block:: python + + shared_limit = limiter.limit("100/hour") + + @app.route("..") + @shared_limit + def r1(): + ... + + @app.route("..") + @shared_limit + def r2(): + ... + + + Dynamic shared limit: when a callable is passed as scope, the return value + of the function will be used as the scope. + + .. code-block:: python + + def host_scope(): + return request.host + host_limit = limiter.shared_limit("100/hour", scope=host_scope) + + @app.route("..") + @host_limit + def r1(): + ... + + @app.route("..") + @host_limit + def r2(): + ... + + + .. note:: Shared rate limits provide the same conveniences as individual rate limits, + + * They can be chained with other shared limits or other individual limits + * They accept keying functions + * accept callables to determine the rate limit value + + + .. _ratelimit-decorator-exempt: :meth:`Limiter.exempt` diff --git a/flask_limiter/extension.py b/flask_limiter/extension.py index 4755c5b3..da42c6a2 100644 --- a/flask_limiter/extension.py +++ b/flask_limiter/extension.py @@ -4,6 +4,7 @@ from functools import wraps import logging +import uuid from flask import request, current_app, g @@ -140,7 +141,9 @@ def __check_request_limit(self): failed_limit = None limit_for_header = None for key_func, limit, scope in (limits + dynamic_limits or self.global_limits): - limit_scope = scope or endpoint + limit_scope = ( + scope if not callable(scope) else scope(endpoint) + ) or endpoint if not limit_for_header or limit < limit_for_header[0]: limit_for_header = (limit, key_func(), limit_scope) if not self.limiter.hit(limit, key_func(), limit_scope): @@ -155,42 +158,63 @@ def __check_request_limit(self): if failed_limit: raise RateLimitExceeded(failed_limit) - def limit(self, limit_value, key_func=None, scope=None): - """ - decorator to be used for rate limiting specific routes. - - :param limit_value: rate limit string or a callable that returns a string. - :ref:`ratelimit-string` for more details. - :param key_func: function/lambda to extract the unique identifier for - the rate limit. defaults to remote address of the request. - :param scope: a string for defining the rate limiting scope. - defaults to the route's endpoint. - :return: - """ + def __limit_decorator(self, limit_value, + key_func=None, shared=False, + scope=None): + _scope = scope or uuid.uuid1().hex if shared else None def _inner(fn): name = "%s.%s" % (fn.__module__, fn.__name__) + @wraps(fn) def __inner(*a, **k): return fn(*a, **k) func = key_func or self.key_func if callable(limit_value): self.dynamic_route_limits.setdefault(name, []).append( - (func, limit_value, scope) + (func, limit_value, _scope) ) else: try: self.route_limits.setdefault(name, []).extend( - [(func, limit, scope) for limit in parse_many(limit_value)] + [(func, limit, _scope) for limit in parse_many(limit_value)] ) except ValueError as e: self.logger.error( "failed to configure view function %s (%s)", name, e ) - return __inner return _inner + + def limit(self, limit_value, key_func=None): + """ + decorator to be used for rate limiting individual routes. + + :param limit_value: rate limit string or a callable that returns a string. + :ref:`ratelimit-string` for more details. + :param key_func: function/lambda to extract the unique identifier for + the rate limit. defaults to remote address of the request. + :return: + """ + return self.__limit_decorator(limit_value, key_func) + + + def shared_limit(self, limit_value, key_func=None, scope=None): + """ + decorator to be applied to multiple routes sharing the same rate limit. + + :param limit_value: rate limit string or a callable that returns a string. + :ref:`ratelimit-string` for more details. + :param key_func: function/lambda to extract the unique identifier for + the rate limit. defaults to remote address of the request. + :param scope: a string or callable that returns a string + for defining the rate limiting scope. Defaults to a :func:`uuid.uuid1` value. + """ + return self.__limit_decorator(limit_value, key_func, True, scope) + + + def exempt(self, fn): """ decorator to mark a view as exempt from rate limits. diff --git a/tests/test_flask_ext.py b/tests/test_flask_ext.py index f7827c77..4f038b85 100644 --- a/tests/test_flask_ext.py +++ b/tests/test_flask_ext.py @@ -423,25 +423,47 @@ def t(): str(int(time.time() + 49)) ) - def test_scope(self): - app = Flask(__name__) - limiter = Limiter(app) - mock_handler = mock.Mock() - mock_handler.level = logging.INFO - limiter.logger.addHandler(mock_handler) + def test_unnamed_shared_limit(self): + app, limiter = self.build_app() + shared_limit_a = limiter.shared_limit("1/minute") + shared_limit_b = limiter.shared_limit("1/minute") + @app.route("/t1") + @shared_limit_a + def route1(): + return "route1" + + @app.route("/t2") + @shared_limit_a + def route2(): + return "route2" + + @app.route("/t3") + @shared_limit_b + def route3(): + return "route3" + + with hiro.Timeline().freeze() as timeline: + with app.test_client() as cli: + self.assertEqual(200, cli.get("/t1").status_code) + self.assertEqual(200, cli.get("/t3").status_code) + self.assertEqual(429, cli.get("/t2").status_code) + def test_named_shared_limit(self): + app, limiter = self.build_app() + shared_limit_a = limiter.shared_limit("1/minute", scope='a') + shared_limit_b = limiter.shared_limit("1/minute", scope='b') @app.route("/t1") - @limiter.limit("1/minute", scope="a") + @shared_limit_a def route1(): return "route1" @app.route("/t2") - @limiter.limit("1/minute", scope="a") + @shared_limit_a def route2(): return "route2" @app.route("/t3") - @limiter.limit("1/minute", scope="b") + @shared_limit_b def route3(): return "route3" @@ -451,6 +473,45 @@ def route3(): self.assertEqual(200, cli.get("/t3").status_code) self.assertEqual(429, cli.get("/t2").status_code) + def test_dynamic_shared_limit(self): + app, limiter = self.build_app() + fn_a = mock.Mock() + fn_b = mock.Mock() + fn_a.return_value = "foo" + fn_b.return_value = "bar" + + + dy_limit_a = limiter.shared_limit("1/minute", scope=fn_a) + dy_limit_b = limiter.shared_limit("1/minute", scope=fn_b) + + + @app.route("/t1") + @dy_limit_a + def t1(): + return "route1" + + @app.route("/t2") + @dy_limit_a + def t2(): + return "route2" + + @app.route("/t3") + @dy_limit_b + def t3(): + return "route3" + + with hiro.Timeline().freeze(): + with app.test_client() as cli: + self.assertEqual(200, cli.get("/t1").status_code) + self.assertEqual(200, cli.get("/t3").status_code) + self.assertEqual(429, cli.get("/t2").status_code) + self.assertEqual(429, cli.get("/t3").status_code) + self.assertEqual(2, fn_a.call_count) + self.assertEqual(2, fn_b.call_count) + fn_b.assert_called_with("t3") + fn_a.assert_has_calls([mock.call("t1"), mock.call("t2")]) + + def test_whitelisting(self): app = Flask(__name__)