Skip to content

Commit

Permalink
Merge b60b478 into d825cff
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed May 15, 2020
2 parents d825cff + b60b478 commit 1a7c39b
Show file tree
Hide file tree
Showing 11 changed files with 160 additions and 126 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ jobs:
pip install -r requirements/ci.txt
- name: Lint with flake8
run: |
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --ignore=C901,W503 --show-source --statistics
flake8 . --count --exit-zero
10 changes: 5 additions & 5 deletions flask_limiter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Flask-Limiter extension for rate limiting
"""
"""Flask-Limiter extension for rate limiting."""
from ._version import get_versions
from .errors import RateLimitExceeded
from .extension import Limiter, HEADERS

__version__ = get_versions()['version']
del get_versions

from .errors import RateLimitExceeded
from .extension import Limiter, HEADERS
__all__ = ['RateLimitExceeded', 'Limiter', 'HEADERS']
8 changes: 3 additions & 5 deletions flask_limiter/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
errors and exceptions
"""
"""errors and exceptions."""

from distutils.version import LooseVersion
from pkg_resources import get_distribution
Expand All @@ -20,8 +18,8 @@


class RateLimitExceeded(werkzeug_exception):
"""
exception raised when a rate limit is hit.
"""exception raised when a rate limit is hit.
The exception results in ``abort(429)`` being called.
"""
code = 429
Expand Down
180 changes: 110 additions & 70 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,36 @@ class HEADERS:

class Limiter(object):
"""
:param app: :class:`flask.Flask` instance to initialize the extension
with.
:param list default_limits: a variable list of strings or callables returning strings denoting global
limits to apply to all routes. :ref:`ratelimit-string` for more details.
:param bool default_limits_per_method: whether default limits are applied per method, per route or as a combination
of all method per route.
:param list application_limits: a variable list of strings or callables returning strings for limits that
are applied to the entire application (i.e a shared limit for all routes)
:param function key_func: a callable that returns the domain to rate limit by.
:param bool headers_enabled: whether ``X-RateLimit`` response headers are written.
:param str strategy: the strategy to use. refer to :ref:`ratelimit-strategy`
:param str storage_uri: the storage location. refer to :ref:`ratelimit-conf`
:param dict storage_options: kwargs to pass to the storage implementation upon
instantiation.
:param bool auto_check: whether to automatically check the rate limit in the before_request
chain of the application. default ``True``
:param bool swallow_errors: whether to swallow errors when hitting a rate limit.
An exception will still be logged. default ``False``
:param list in_memory_fallback: a variable list of strings or callables returning strings denoting fallback
limits to apply when the storage is down.
:param bool in_memory_fallback_enabled: simply falls back to in memory storage
when the main storage is down and inherits the original limits.
The :class:`Limiter` class initializes the Flask-Limiter extension.
:param app: :class:`flask.Flask` instance to initialize the extension with.
:param list default_limits: a variable list of strings or callables
returning strings denoting global limits to apply to all routes.
:ref:`ratelimit-string` for more details.
:param bool default_limits_per_method: whether default limits are applied
per method, per route or as a combination of all method per route.
:param list application_limits: a variable list of strings or callables
returning strings for limits that are applied to the entire application
(i.e a shared limit for all routes)
:param function key_func: a callable that returns the domain to rate limit
by.
:param bool headers_enabled: whether ``X-RateLimit`` response headers are
written.
:param str strategy: the strategy to use.
Refer to :ref:`ratelimit-strategy`
:param str storage_uri: the storage location.
Refer to :ref:`ratelimit-conf`
:param dict storage_options: kwargs to pass to the storage implementation
upon instantiation.
:param bool auto_check: whether to automatically check the rate limit in
the before_request chain of the application. default ``True``
:param bool swallow_errors: whether to swallow errors when hitting a rate
limit. An exception will still be logged. default ``False``
:param list in_memory_fallback: a variable list of strings or callables
returning strings denoting fallback limits to apply when the storage is
down.
:param bool in_memory_fallback_enabled: simply falls back to in memory
storage when the main storage is down and inherits the original limits.
:param str key_prefix: prefix prepended to rate limiter keys.
"""

Expand Down Expand Up @@ -107,7 +115,10 @@ def __init__(
self._default_limits_per_method = default_limits_per_method
self._application_limits = []
self._in_memory_fallback = []
self._in_memory_fallback_enabled = in_memory_fallback_enabled or len(in_memory_fallback) > 0
self._in_memory_fallback_enabled = (
in_memory_fallback_enabled
or len(in_memory_fallback) > 0
)
self._exempt_routes = set()
self._request_filters = []
self._headers_enabled = headers_enabled
Expand All @@ -121,7 +132,8 @@ def __init__(
if not key_func:
warnings.warn(
"Use of the default `get_ipaddr` function is discouraged."
" Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
" Please refer to"
" https://flask-limiter.readthedocs.org/#rate-limit-domain"
" for the recommended configuration", UserWarning
)
if global_limits:
Expand Down Expand Up @@ -263,7 +275,10 @@ def init_app(self, app):
)
]
if not self._in_memory_fallback_enabled:
self._in_memory_fallback_enabled = fallback_enabled or len(self._in_memory_fallback) > 0
self._in_memory_fallback_enabled = (
fallback_enabled
or len(self._in_memory_fallback) > 0
)

if self._in_memory_fallback_enabled:
self._fallback_storage = MemoryStorage()
Expand Down Expand Up @@ -333,18 +348,24 @@ def __inject_headers(self, response):
response.headers.add(
self._header_mapping[HEADERS.REMAINING], window_stats[1]
)
response.headers.add(self._header_mapping[HEADERS.RESET], reset_in)
response.headers.add(
self._header_mapping[HEADERS.RESET], reset_in
)

# response may have an existing retry after
existing_retry_after_header = response.headers.get('Retry-After')
existing_retry_after_header = response.headers.get(
'Retry-After'
)

if existing_retry_after_header is not None:
# might be in http-date format
retry_after = parse_date(existing_retry_after_header)

# parse_date failure returns None
if retry_after is None:
retry_after = time.time() + int(existing_retry_after_header)
retry_after = time.time() + int(
existing_retry_after_header
)

if isinstance(retry_after, datetime.datetime):
retry_after = time.mktime(retry_after.timetuple())
Expand All @@ -357,7 +378,7 @@ def __inject_headers(self, response):
self._retry_after == 'http-date' and http_date(reset_in)
or int(reset_in - time.time())
)
except:
except: # noqa: E722
if self._in_memory_fallback and not self._storage_dead:
self.logger.warn(
"Rate limit storage unreachable - falling back to"
Expand Down Expand Up @@ -416,7 +437,8 @@ def __check_request_limit(self, in_middleware=True):
"%s.%s" % (view_func.__module__, view_func.__name__)
if view_func else ""
)
if (not request.endpoint
if (
not request.endpoint
or not self.enabled
or view_func == current_app.send_static_file
or name in self._exempt_routes
Expand Down Expand Up @@ -458,11 +480,13 @@ def __check_request_limit(self, in_middleware=True):
dynamic_limits.extend(list(lim))
except ValueError as e:
self.logger.error(
"failed to load ratelimit for view function %s (%s)",
"failed to load ratelimit for "
"view function %s (%s)",
name, e
)
if request.blueprint:
if (request.blueprint in self._blueprint_dynamic_limits
if (
request.blueprint in self._blueprint_dynamic_limits
and not dynamic_limits
):
for limit_group in self._blueprint_dynamic_limits[
Expand Down Expand Up @@ -503,19 +527,27 @@ def __check_request_limit(self, in_middleware=True):
)
if not all_limits:
route_limits = limits + dynamic_limits
all_limits = list(itertools.chain(*self._application_limits)) if in_middleware else []
all_limits = list(
itertools.chain(*self._application_limits)
) if in_middleware else []
all_limits += route_limits
explicit_limits_exempt = all(
limit.method_exempt for limit in route_limits
)
combined_defaults = all(
not limit.override_defaults for limit in route_limits
)
before_request_context = (
in_middleware and name in self.__marked_for_limiting
)
if (
(
all(limit.method_exempt for limit in route_limits)
or all(not limit.override_defaults for limit in route_limits)
)
and not (in_middleware and name in self.__marked_for_limiting)
or implicit_decorator
(explicit_limits_exempt or combined_defaults)
and not before_request_context
or implicit_decorator
):
all_limits += list(itertools.chain(*self._default_limits))
all_limits += list(itertools.chain(*self._default_limits))
self.__evaluate_limits(endpoint, all_limits)
except Exception as e: # no qa
except Exception as e:
if isinstance(e, RateLimitExceeded):
six.reraise(*sys.exc_info())
if self._in_memory_fallback_enabled and not self._storage_dead:
Expand Down Expand Up @@ -594,7 +626,10 @@ def _inner(obj):

@wraps(obj)
def __inner(*a, **k):
if self._auto_check and not g.get("_rate_limiting_complete"):
if (
self._auto_check
and not g.get("_rate_limiting_complete")
):
self.__check_request_limit(False)
g._rate_limiting_complete = True
return obj(*a, **k)
Expand All @@ -614,20 +649,21 @@ def limit(
"""
decorator to be used for rate limiting individual routes or blueprints.
:param limit_value: rate limit string or a callable that returns a string.
:ref:`ratelimit-string` for more details.
:param function key_func: function/lambda to extract the unique identifier for
the rate limit. defaults to remote address of the request.
:param bool per_method: whether the limit is sub categorized into the http
method of the request.
:param list methods: if specified, only the methods in this list will be rate
limited (default: None).
:param error_message: string (or callable that returns one) to override the
error message used in the response.
:param function exempt_when: function/lambda used to decide if the rate limit
should skipped.
:param bool override_defaults: whether the decorated limit overrides the default
limits. (default: True)
:param limit_value: rate limit string or a callable that returns a
string. :ref:`ratelimit-string` for more details.
:param function key_func: function/lambda to extract the unique
identifier for the rate limit. defaults to remote address of the
request.
:param bool per_method: whether the limit is sub categorized into the
http method of the request.
:param list methods: if specified, only the methods in this list will
be rate limited (default: None).
:param error_message: string (or callable that returns one) to override
the error message used in the response.
:param function exempt_when: function/lambda used to decide if the rate
limit should skipped.
:param bool override_defaults: whether the decorated limit overrides
the default limits. (default: True)
"""
return self.__limit_decorator(
limit_value,
Expand All @@ -651,18 +687,19 @@ def shared_limit(
"""
decorator to be applied to multiple routes sharing the same rate limit.
:param limit_value: rate limit string or a callable that returns a string.
:ref:`ratelimit-string` for more details.
:param limit_value: rate limit string or a callable that returns a
string. :ref:`ratelimit-string` for more details.
:param scope: a string or callable that returns a string
for defining the rate limiting scope.
:param function key_func: function/lambda to extract the unique identifier for
the rate limit. defaults to remote address of the request.
:param error_message: string (or callable that returns one) to override the
error message used in the response.
:param function exempt_when: function/lambda used to decide if the rate limit
should skipped.
:param bool override_defaults: whether the decorated limit overrides the default
limits. (default: True)
:param function key_func: function/lambda to extract the unique
identifier for the rate limit. defaults to remote address of the
request.
:param error_message: string (or callable that returns one) to override
the error message used in the response.
:param function exempt_when: function/lambda used to decide if the rate
limit should skipped.
:param bool override_defaults: whether the decorated limit overrides
the default limits. (default: True)
"""
return self.__limit_decorator(
limit_value,
Expand All @@ -676,7 +713,8 @@ def shared_limit(

def exempt(self, obj):
"""
decorator to mark a view or all views in a blueprint as exempt from rate limits.
decorator to mark a view or all views in a blueprint as exempt from
rate limits.
"""
if not isinstance(obj, Blueprint):
name = "%s.%s" % (obj.__module__, obj.__name__)
Expand All @@ -700,7 +738,9 @@ def request_filter(self, fn):

def raise_global_limits_warning(self):
warnings.warn(
"global_limits was a badly name configuration since it is actually a default limit and not a "
"globally shared limit. Use default_limits if you want to provide a default or use application_limits "
"if you intend to really have a global shared limit", UserWarning
"global_limits was a badly name configuration since it is "
"actually a default limit and not a globally shared limit. Use "
"default_limits if you want to provide a default or use "
"application_limits if you intend to really have a global "
"shared limit", UserWarning
)
13 changes: 6 additions & 7 deletions flask_limiter/util.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""
"""
""""""

from flask import request


def get_ipaddr():
"""
:return: the ip address for the current request (or 127.0.0.1 if none found)
based on the X-Forwarded-For headers.
"""
:return: the ip address for the current request
(or 127.0.0.1 if none found) based on the X-Forwarded-For headers.
"""
if request.access_route:
return request.access_route[0]
else:
Expand All @@ -18,6 +16,7 @@ def get_ipaddr():

def get_remote_address():
"""
:return: the ip address for the current request (or 127.0.0.1 if none found)
:return: the ip address for the current request
(or 127.0.0.1 if none found)
"""
return request.remote_addr or '127.0.0.1'
11 changes: 8 additions & 3 deletions flask_limiter/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@ def scope(self):
@property
def method_exempt(self):
"""Check if the limit is not applicable for this method"""
return self.methods is not None and request.method.lower() not in self.methods
return (
self.methods is not None
and request.method.lower() not in self.methods
)


class LimitGroup(object):
"""
represents a group of related limits either from a string or a callable that returns one
represents a group of related limits either from a string or a callable
that returns one
"""

def __init__(
Expand All @@ -63,5 +67,6 @@ def __iter__(self):
for limit in limit_items:
yield Limit(
limit, self.key_function, self.__scope, self.per_method,
self.methods, self.error_message, self.exempt_when, self.override_defaults
self.methods, self.error_message, self.exempt_when,
self.override_defaults
)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ parentdir_prefix = flask-limiter-

[flake8]
exclude = build/**,doc/**,_version.py,version.py,versioneer.py
max_complexity = 10
Loading

0 comments on commit 1a7c39b

Please sign in to comment.