-
-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move each storage impl to its own file.
- Loading branch information
Showing
11 changed files
with
898 additions
and
838 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from six.moves import urllib | ||
|
||
from limits.errors import ConfigurationError | ||
from .memory import MemoryStorage | ||
|
||
from .base import Storage | ||
from .registry import SCHEMES | ||
from .redis import RedisStorage | ||
from .redis_cluster import RedisClusterStorage | ||
from .redis_sentinel import RedisSentinelStorage | ||
from .memcached import MemcachedStorage | ||
from .gae_memcached import GAEMemcachedStorage | ||
|
||
|
||
def storage_from_string(storage_string, **options): | ||
""" | ||
factory function to get an instance of the storage class based | ||
on the uri of the storage | ||
:param storage_string: a string of the form method://host:port | ||
:return: an instance of :class:`flask_limiter.storage.Storage` | ||
""" | ||
scheme = urllib.parse.urlparse(storage_string).scheme | ||
if scheme not in SCHEMES: | ||
raise ConfigurationError( | ||
"unknown storage scheme : %s" % storage_string | ||
) | ||
return SCHEMES[scheme](storage_string, **options) | ||
|
||
|
||
__all__ = [ | ||
"storage_from_string", | ||
"Storage", | ||
"MemoryStorage", | ||
"RedisStorage", | ||
"RedisClusterStorage", | ||
"RedisSentinelStorage", | ||
"MemcachedStorage", | ||
"GAEMemcachedStorage" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import threading | ||
from abc import ABCMeta, abstractmethod | ||
|
||
import six | ||
|
||
from limits.storage.registry import StorageRegistry | ||
|
||
|
||
@six.add_metaclass(StorageRegistry) | ||
@six.add_metaclass(ABCMeta) | ||
class Storage(object): | ||
""" | ||
Base class to extend when implementing a storage backend. | ||
""" | ||
|
||
def __init__(self, uri=None, **options): | ||
self.lock = threading.RLock() | ||
|
||
@abstractmethod | ||
def incr(self, key, expiry, elastic_expiry=False): | ||
""" | ||
increments the counter for a given rate limit key | ||
:param str key: the key to increment | ||
:param int expiry: amount in seconds for the key to expire in | ||
:param bool elastic_expiry: whether to keep extending the rate limit | ||
window every hit. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get(self, key): | ||
""" | ||
:param str key: the key to get the counter value for | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get_expiry(self, key): | ||
""" | ||
:param str key: the key to get the expiry for | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def check(self): | ||
""" | ||
check if storage is healthy | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def reset(self): | ||
""" | ||
reset storage to clear limits | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def clear(self, key): | ||
""" | ||
resets the rate limit key | ||
:param str key: the key to clear rate limits for | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import time | ||
|
||
from .memcached import MemcachedStorage | ||
|
||
|
||
class GAEMemcachedStorage(MemcachedStorage): | ||
""" | ||
rate limit storage with GAE memcache as backend | ||
""" | ||
MAX_CAS_RETRIES = 10 | ||
STORAGE_SCHEME = ["gaememcached"] | ||
|
||
def __init__(self, uri, **options): | ||
options["library"] = "google.appengine.api.memcache" | ||
super(GAEMemcachedStorage, self).__init__(uri, **options) | ||
|
||
def incr(self, key, expiry, elastic_expiry=False): | ||
""" | ||
increments the counter for a given rate limit key | ||
:param str key: the key to increment | ||
:param int expiry: amount in seconds for the key to expire in | ||
:param bool elastic_expiry: whether to keep extending the rate limit | ||
window every hit. | ||
""" | ||
if not self.call_memcached_func(self.storage.add, key, 1, expiry): | ||
if elastic_expiry: | ||
# CAS id is set as state on the client object in GAE memcache | ||
value = self.storage.gets(key) | ||
retry = 0 | ||
while ( | ||
not self.call_memcached_func( | ||
self.storage.cas, key, | ||
int(value or 0) + 1, expiry | ||
) and retry < self.MAX_CAS_RETRIES | ||
): | ||
value = self.storage.gets(key) | ||
retry += 1 | ||
self.call_memcached_func( | ||
self.storage.set, key + "/expires", expiry + time.time(), | ||
expiry | ||
) | ||
return int(value or 0) + 1 | ||
else: | ||
return self.storage.incr(key, 1) | ||
self.call_memcached_func( | ||
self.storage.set, key + "/expires", expiry + time.time(), expiry | ||
) | ||
return 1 | ||
|
||
def check(self): | ||
""" | ||
check if storage is healthy | ||
""" | ||
try: | ||
self.call_memcached_func(self.storage.get_stats) | ||
return True | ||
except: # noqa | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import inspect | ||
import threading | ||
import time | ||
|
||
from six.moves import urllib | ||
|
||
from ..errors import ConfigurationError | ||
from ..util import get_dependency | ||
from .base import Storage | ||
|
||
|
||
class MemcachedStorage(Storage): | ||
""" | ||
Rate limit storage with memcached as backend. | ||
Depends on the `pymemcache` library. | ||
""" | ||
MAX_CAS_RETRIES = 10 | ||
STORAGE_SCHEME = ["memcached"] | ||
|
||
def __init__(self, uri, **options): | ||
""" | ||
:param str uri: memcached location of the form | ||
`memcached://host:port,host:port`, `memcached:///var/tmp/path/to/sock` | ||
:param options: all remaining keyword arguments are passed | ||
directly to the constructor of :class:`pymemcache.client.base.Client` | ||
:raise ConfigurationError: when `pymemcache` is not available | ||
""" | ||
parsed = urllib.parse.urlparse(uri) | ||
self.hosts = [] | ||
for loc in parsed.netloc.strip().split(","): | ||
if not loc: | ||
continue | ||
host, port = loc.split(":") | ||
self.hosts.append((host, int(port))) | ||
else: | ||
# filesystem path to UDS | ||
if parsed.path and not parsed.netloc and not parsed.port: | ||
self.hosts = [parsed.path] | ||
|
||
self.library = options.pop('library', 'pymemcache.client') | ||
self.cluster_library = options.pop('library', 'pymemcache.client.hash') | ||
self.client_getter = options.pop('client_getter', self.get_client) | ||
self.options = options | ||
|
||
if not get_dependency(self.library): | ||
raise ConfigurationError( | ||
"memcached prerequisite not available." | ||
" please install %s" % self.library | ||
) # pragma: no cover | ||
self.local_storage = threading.local() | ||
self.local_storage.storage = None | ||
|
||
def get_client(self, module, hosts, **kwargs): | ||
""" | ||
returns a memcached client. | ||
:param module: the memcached module | ||
:param hosts: list of memcached hosts | ||
:return: | ||
""" | ||
return ( | ||
module.HashClient(hosts, **kwargs) | ||
if len(hosts) > 1 else module.Client(*hosts, **kwargs) | ||
) | ||
|
||
def call_memcached_func(self, func, *args, **kwargs): | ||
if 'noreply' in kwargs: | ||
argspec = inspect.getargspec(func) | ||
if not ('noreply' in argspec.args or argspec.keywords): | ||
kwargs.pop('noreply') # noqa | ||
return func(*args, **kwargs) | ||
|
||
@property | ||
def storage(self): | ||
""" | ||
lazily creates a memcached client instance using a thread local | ||
""" | ||
if not ( | ||
hasattr(self.local_storage, "storage") | ||
and self.local_storage.storage | ||
): | ||
self.local_storage.storage = self.client_getter( | ||
get_dependency( | ||
self.cluster_library if len(self.hosts) > 1 | ||
else self.library | ||
), | ||
self.hosts, **self.options | ||
) | ||
|
||
return self.local_storage.storage | ||
|
||
def get(self, key): | ||
""" | ||
:param str key: the key to get the counter value for | ||
""" | ||
return int(self.storage.get(key) or 0) | ||
|
||
def clear(self, key): | ||
""" | ||
:param str key: the key to clear rate limits for | ||
""" | ||
self.storage.delete(key) | ||
|
||
def incr(self, key, expiry, elastic_expiry=False): | ||
""" | ||
increments the counter for a given rate limit key | ||
:param str key: the key to increment | ||
:param int expiry: amount in seconds for the key to expire in | ||
:param bool elastic_expiry: whether to keep extending the rate limit | ||
window every hit. | ||
""" | ||
if not self.call_memcached_func( | ||
self.storage.add, key, 1, expiry, noreply=False | ||
): | ||
if elastic_expiry: | ||
value, cas = self.storage.gets(key) | ||
retry = 0 | ||
while ( | ||
not self.call_memcached_func( | ||
self.storage.cas, key, | ||
int(value or 0) + 1, cas, expiry | ||
) and retry < self.MAX_CAS_RETRIES | ||
): | ||
value, cas = self.storage.gets(key) | ||
retry += 1 | ||
self.call_memcached_func( | ||
self.storage.set, | ||
key + "/expires", | ||
expiry + time.time(), | ||
expire=expiry, | ||
noreply=False | ||
) | ||
return int(value or 0) + 1 | ||
else: | ||
return self.storage.incr(key, 1) | ||
self.call_memcached_func( | ||
self.storage.set, | ||
key + "/expires", | ||
expiry + time.time(), | ||
expire=expiry, | ||
noreply=False | ||
) | ||
return 1 | ||
|
||
def get_expiry(self, key): | ||
""" | ||
:param str key: the key to get the expiry for | ||
""" | ||
return int(float(self.storage.get(key + "/expires") or time.time())) | ||
|
||
def check(self): | ||
""" | ||
check if storage is healthy | ||
""" | ||
try: | ||
self.call_memcached_func(self.storage.get, 'limiter-check') | ||
return True | ||
except: # noqa | ||
return False |
Oops, something went wrong.