Skip to content

Commit

Permalink
Merge pull request #14 from alisaifee/blueprints
Browse files Browse the repository at this point in the history
Rate limits for blueprints
  • Loading branch information
alisaifee committed Jul 13, 2014
2 parents 51c7df7 + b73a14c commit f374d3b
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 28 deletions.
39 changes: 39 additions & 0 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,45 @@ work. You can add rate limits to your view classes using the following approach.
The above approach has been tested with sub-classes of :class:`flask.views.View`,
:class:`flask.views.MethodView` and :class:`flask.ext.restful.Resource`.

Rate limiting all routes in a :class:`flask.Blueprint`
------------------------------------------------------
:meth:`Limiter.limit`, :meth:`Limiter.shared_limit` & :meth:`Limiter.exempt` can
all be applied to :class:`flask.Blueprint` instances as well. In the following example
the **login** Blueprint has a special rate limit applied to all its routes, while the **help**
Blueprint is exempt from all rate limits. The **regular** Blueprint follows the global rate limits.


.. code-block:: python
app = Flask(__name__)
login = Blueprint("login", __name__, url_prefix = "/login")
regular = Blueprint("regular", __name__, url_prefix = "/regular")
doc = Blueprint("doc", __name__, url_prefix = "/doc")
@doc.route("/")
def doc_index():
return "doc"
@regular.route("/")
def regular_index():
return "regular"
@login.route("/")
def login_index():
return "login"
limiter = Limiter(app, global_limits = ["1/second"])
limiter.limit("60/hour")(login)
limiter.exempt(doc)
app.register_blueprint(doc)
app.register_blueprint(login)
app.register_blueprint(regular)
.. _logging:

Logging
Expand Down
98 changes: 71 additions & 27 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import wraps
import logging

from flask import request, current_app, g
from flask import request, current_app, g, Blueprint

from .errors import RateLimitExceeded, ConfigurationError
from flask.ext.limiter.storage import storage_from_string
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self, app=None
self.app = app
self.enabled = True
self.global_limits = []
self.exempt_routes = []
self.exempt_routes = set()
self.request_filters = []
self.headers_enabled = headers_enabled
self.strategy = strategy
Expand All @@ -76,6 +76,9 @@ def __init__(self, app=None
)
self.route_limits = {}
self.dynamic_route_limits = {}
self.blueprint_limits = {}
self.blueprint_dynamic_limits = {}
self.blueprint_exempt = set()
self.storage = self.limiter = None
self.key_func = key_func
self.logger = logging.getLogger("flask-limiter")
Expand Down Expand Up @@ -151,6 +154,7 @@ def __check_request_limit(self):
or name in self.exempt_routes
or not self.enabled
or any(fn() for fn in self.request_filters)
or request.blueprint in self.blueprint_exempt
):
return
limits = (
Expand All @@ -171,6 +175,26 @@ def __check_request_limit(self):
"failed to load ratelimit for view function %s (%s)"
, name, e
)
if request.blueprint:
if (request.blueprint in self.blueprint_dynamic_limits
and not dynamic_limits
):
for lim in self.blueprint_dynamic_limits[request.blueprint]:
try:
dynamic_limits.extend(
ExtLimit(
limit, lim.key_func, lim.scope, lim.per_method
) for limit in parse_many(lim.limit)
)
except ValueError as e:
self.logger.error(
"failed to load ratelimit for blueprint %s (%s)"
, request.blueprint, e
)
if (request.blueprint in self.blueprint_limits
and not limits
):
limits.extend(self.blueprint_limits[request.blueprint])

failed_limit = None
limit_for_header = None
Expand Down Expand Up @@ -198,35 +222,52 @@ def __limit_decorator(self, limit_value,
per_method=False):
_scope = scope if shared else None

def _inner(fn):
name = "%s.%s" % (fn.__module__, fn.__name__)

@wraps(fn)
def __inner(*a, **k):
return fn(*a, **k)
def _inner(obj):
func = key_func or self.key_func
is_route = not isinstance(obj, Blueprint)
name = "%s.%s" % (obj.__module__, obj.__name__) if is_route else obj.name
dynamic_limit, static_limits = None, []
if callable(limit_value):
self.dynamic_route_limits.setdefault(name, []).append(
ExtLimit(limit_value, func, _scope, per_method)
)
dynamic_limit = ExtLimit(limit_value, func, _scope, per_method)
else:
try:
self.route_limits.setdefault(name, []).extend(
[ExtLimit(
limit, func, _scope, per_method
) for limit in parse_many(limit_value)]
)
static_limits = [ExtLimit(
limit, func, _scope, per_method
) for limit in parse_many(limit_value)]
except ValueError as e:
self.logger.error(
"failed to configure view function %s (%s)", name, e
"failed to configure %s %s (%s)",
"view function" if is_route else "blueprint", name, e
)
return __inner
if not isinstance(obj, Blueprint):
@wraps(obj)
def __inner(*a, **k):
return obj(*a, **k)
if dynamic_limit:
self.dynamic_route_limits.setdefault(name, []).append(
dynamic_limit
)
else:
self.route_limits.setdefault(name, []).extend(
static_limits
)
return __inner
else:
if dynamic_limit:
self.blueprint_dynamic_limits.setdefault(name, []).append(
dynamic_limit
)
else:
self.blueprint_limits.setdefault(name, []).extend(
static_limits
)

return _inner


def limit(self, limit_value, key_func=None, per_method=False):
"""
decorator to be used for rate limiting individual routes.
decorator to be used for rate limiting individual routes or blueprints.
:param limit_value: rate limit string or a callable that returns a string.
:ref:`ratelimit-string` for more details.
Expand Down Expand Up @@ -254,16 +295,19 @@ def shared_limit(self, limit_value, scope, key_func=None):



def exempt(self, fn):
def exempt(self, obj):
"""
decorator to mark a view as exempt from rate limits.
decorator to mark a view or all views in a route as exempt from rate limits.
"""
name = "%s.%s" % (fn.__module__, fn.__name__)
@wraps(fn)
def __inner(*a, **k):
return fn(*a, **k)
self.exempt_routes.append(name)
return __inner
if not isinstance(obj, Blueprint):
name = "%s.%s" % (obj.__module__, obj.__name__)
@wraps(obj)
def __inner(*a, **k):
return obj(*a, **k)
self.exempt_routes.add(name)
return __inner
else:
self.blueprint_exempt.add(obj.name)

def request_filter(self, fn):
"""
Expand Down
96 changes: 95 additions & 1 deletion tests/test_flask_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,56 @@ def t2():
self.assertEqual(cli.get("/t2").status_code, 200)
self.assertEqual(cli.get("/t2").status_code, 429)


def test_register_blueprint(self):
app, limiter = self.build_app(global_limits = ["1/minute"])
bp_1 = Blueprint("bp1", __name__)
bp_2 = Blueprint("bp2", __name__)
bp_3 = Blueprint("bp3", __name__)

@bp_1.route("/t1")
def t1():
return "test"

@bp_1.route("/t2")
def t2():
return "test"

@bp_2.route("/t3")
def t3():
return "test"

@bp_3.route("/t4")
def t4():
return "test"

app.register_blueprint(bp_1)
app.register_blueprint(bp_2)
app.register_blueprint(bp_3)

limiter.limit("1/second")(bp_1)
limiter.exempt(bp_3)

with hiro.Timeline().freeze() as timeline:
with app.test_client() as cli:
self.assertEqual(cli.get("/t1").status_code, 200)
self.assertEqual(cli.get("/t1").status_code, 429)
timeline.forward(1)
self.assertEqual(cli.get("/t1").status_code, 200)
self.assertEqual(cli.get("/t2").status_code, 200)
self.assertEqual(cli.get("/t2").status_code, 429)
timeline.forward(1)
self.assertEqual(cli.get("/t2").status_code, 200)

self.assertEqual(cli.get("/t3").status_code, 200)
for i in range(0,10):
timeline.forward(1)
self.assertEqual(cli.get("/t3").status_code, 429)

for i in range(0,10):
self.assertEqual(cli.get("/t4").status_code, 200)


def test_disabled_flag(self):
app, limiter = self.build_app(
config={C.ENABLED: False},
Expand Down Expand Up @@ -292,9 +342,53 @@ def t1():
with hiro.Timeline().freeze() as timeline:
self.assertEqual(cli.get("/t1").status_code, 200)
self.assertEqual(cli.get("/t1").status_code, 429)
self.assertTrue("failed to configure view function" in mock_handler.handle.call_args_list[0][0][0].msg)
self.assertTrue("failed to configure" in mock_handler.handle.call_args_list[0][0][0].msg)
self.assertTrue("exceeded at endpoint" in mock_handler.handle.call_args_list[1][0][0].msg)

def test_invalid_decorated_static_limit_blueprint(self):
app = Flask(__name__)
limiter = Limiter(app, global_limits=["1/second"])
mock_handler = mock.Mock()
mock_handler.level = logging.INFO
limiter.logger.addHandler(mock_handler)
bp = Blueprint("bp1", __name__)

@bp.route("/t1")
def t1():
return "42"
limiter.limit("2/sec")(bp)
app.register_blueprint(bp)

with app.test_client() as cli:
with hiro.Timeline().freeze() as timeline:
self.assertEqual(cli.get("/t1").status_code, 200)
self.assertEqual(cli.get("/t1").status_code, 429)
self.assertTrue("failed to configure" in mock_handler.handle.call_args_list[0][0][0].msg)
self.assertTrue("exceeded at endpoint" in mock_handler.handle.call_args_list[1][0][0].msg)

def test_invalid_decorated_dynamic_limits_blueprint(self):
app = Flask(__name__)
app.config.setdefault("X", "2 per sec")
limiter = Limiter(app, global_limits=["1/second"])
mock_handler = mock.Mock()
mock_handler.level = logging.INFO
limiter.logger.addHandler(mock_handler)
bp = Blueprint("bp1", __name__)
@bp.route("/t1")
def t1():
return "42"

limiter.limit(lambda: current_app.config.get("X"))(bp)
app.register_blueprint(bp)

with app.test_client() as cli:
with hiro.Timeline().freeze() as timeline:
self.assertEqual(cli.get("/t1").status_code, 200)
self.assertEqual(cli.get("/t1").status_code, 429)
self.assertEqual(mock_handler.handle.call_count, 3)
self.assertTrue("failed to load ratelimit" in mock_handler.handle.call_args_list[0][0][0].msg)
self.assertTrue("failed to load ratelimit" in mock_handler.handle.call_args_list[1][0][0].msg)
self.assertTrue("exceeded at endpoint" in mock_handler.handle.call_args_list[2][0][0].msg)

def test_multiple_apps(self):
app1 = Flask(__name__)
Expand Down

0 comments on commit f374d3b

Please sign in to comment.