Skip to content

Commit

Permalink
Move each storage impl to its own file.
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed May 19, 2020
1 parent e72a8a1 commit 1a06bde
Show file tree
Hide file tree
Showing 11 changed files with 898 additions and 838 deletions.
832 changes: 0 additions & 832 deletions limits/storage.py

This file was deleted.

40 changes: 40 additions & 0 deletions limits/storage/__init__.py
@@ -0,0 +1,40 @@
from six.moves import urllib

from limits.errors import ConfigurationError
from .memory import MemoryStorage

from .base import Storage
from .registry import SCHEMES
from .redis import RedisStorage
from .redis_cluster import RedisClusterStorage
from .redis_sentinel import RedisSentinelStorage
from .memcached import MemcachedStorage
from .gae_memcached import GAEMemcachedStorage


def storage_from_string(storage_string, **options):
    """
    Factory function to get an instance of the storage class based
    on the uri of the storage.

    :param storage_string: a string of the form ``method://host:port``
    :param options: all remaining keyword arguments are passed to the
     constructor of the matched storage class
    :return: an instance of :class:`limits.storage.Storage`
    :raise ConfigurationError: when the uri scheme is not registered
     in :data:`SCHEMES`
    """
    scheme = urllib.parse.urlparse(storage_string).scheme
    if scheme not in SCHEMES:
        raise ConfigurationError(
            "unknown storage scheme : %s" % storage_string
        )
    return SCHEMES[scheme](storage_string, **options)


# Declared public API of the ``limits.storage`` package; also controls
# what ``from limits.storage import *`` exports.
__all__ = [
    "storage_from_string",
    "Storage",
    "MemoryStorage",
    "RedisStorage",
    "RedisClusterStorage",
    "RedisSentinelStorage",
    "MemcachedStorage",
    "GAEMemcachedStorage"
]
65 changes: 65 additions & 0 deletions limits/storage/base.py
@@ -0,0 +1,65 @@
import threading
from abc import ABCMeta, abstractmethod

import six

from limits.storage.registry import StorageRegistry


# NOTE(review): ``add_metaclass`` is applied twice; the outer
# ``StorageRegistry`` application recreates the class last, making it the
# effective metaclass — presumably ``StorageRegistry`` derives from (or is
# compatible with) ``ABCMeta``; confirm in ``limits/storage/registry.py``.
@six.add_metaclass(StorageRegistry)
@six.add_metaclass(ABCMeta)
class Storage(object):
    """
    Base class to extend when implementing a storage backend.

    Concrete subclasses implement the abstract methods below and
    (by convention in this package) declare a ``STORAGE_SCHEME`` list of
    uri schemes they serve.
    """

    def __init__(self, uri=None, **options):
        # Reentrant lock available to subclasses for guarding
        # non-atomic read-modify-write sequences.
        self.lock = threading.RLock()

    @abstractmethod
    def incr(self, key, expiry, elastic_expiry=False):
        """
        increments the counter for a given rate limit key

        :param str key: the key to increment
        :param int expiry: amount in seconds for the key to expire in
        :param bool elastic_expiry: whether to keep extending the rate limit
         window every hit.
        :return: the new counter value (see concrete implementations)
        """
        raise NotImplementedError

    @abstractmethod
    def get(self, key):
        """
        returns the current counter value

        :param str key: the key to get the counter value for
        """
        raise NotImplementedError

    @abstractmethod
    def get_expiry(self, key):
        """
        returns the expiry time of the key

        :param str key: the key to get the expiry for
        """
        raise NotImplementedError

    @abstractmethod
    def check(self):
        """
        check if storage is healthy
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self):
        """
        reset storage to clear limits
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self, key):
        """
        resets the rate limit key

        :param str key: the key to clear rate limits for
        """
        raise NotImplementedError
59 changes: 59 additions & 0 deletions limits/storage/gae_memcached.py
@@ -0,0 +1,59 @@
import time

from .memcached import MemcachedStorage


class GAEMemcachedStorage(MemcachedStorage):
    """
    Rate limit storage with GAE memcache as backend.

    Behaves like :class:`MemcachedStorage` but pins the client library
    to ``google.appengine.api.memcache``.
    """
    MAX_CAS_RETRIES = 10
    STORAGE_SCHEME = ["gaememcached"]

    def __init__(self, uri, **options):
        # Force the GAE client library regardless of caller-supplied options.
        options["library"] = "google.appengine.api.memcache"
        super(GAEMemcachedStorage, self).__init__(uri, **options)

    def incr(self, key, expiry, elastic_expiry=False):
        """
        increments the counter for a given rate limit key

        :param str key: the key to increment
        :param int expiry: amount in seconds for the key to expire in
        :param bool elastic_expiry: whether to keep extending the rate limit
         window every hit.
        """
        if self.call_memcached_func(self.storage.add, key, 1, expiry):
            # Fresh key: record its absolute expiry alongside the counter.
            self.call_memcached_func(
                self.storage.set, key + "/expires", expiry + time.time(),
                expiry
            )
            return 1
        if not elastic_expiry:
            return self.storage.incr(key, 1)
        # Elastic window: bump the counter via compare-and-set and push
        # the expiry forward. CAS id is set as state on the client object
        # in GAE memcache.
        value = self.storage.gets(key)
        attempts = 0
        while (
            not self.call_memcached_func(
                self.storage.cas, key,
                int(value or 0) + 1, expiry
            ) and attempts < self.MAX_CAS_RETRIES
        ):
            value = self.storage.gets(key)
            attempts += 1
        self.call_memcached_func(
            self.storage.set, key + "/expires", expiry + time.time(),
            expiry
        )
        return int(value or 0) + 1

    def check(self):
        """
        check if storage is healthy
        """
        try:
            self.call_memcached_func(self.storage.get_stats)
        except:  # noqa
            return False
        return True
160 changes: 160 additions & 0 deletions limits/storage/memcached.py
@@ -0,0 +1,160 @@
import inspect
import threading
import time

from six.moves import urllib

from ..errors import ConfigurationError
from ..util import get_dependency
from .base import Storage


class MemcachedStorage(Storage):
    """
    Rate limit storage with memcached as backend.

    Depends on the `pymemcache` library.
    """
    MAX_CAS_RETRIES = 10
    STORAGE_SCHEME = ["memcached"]

    def __init__(self, uri, **options):
        """
        :param str uri: memcached location of the form
         `memcached://host:port,host:port`,
         `memcached:///var/tmp/path/to/sock`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`pymemcache.client.base.Client`
        :raise ConfigurationError: when `pymemcache` is not available
        """
        parsed = urllib.parse.urlparse(uri)
        self.hosts = []
        for loc in parsed.netloc.strip().split(","):
            if not loc:
                continue
            host, port = loc.split(":")
            self.hosts.append((host, int(port)))
        # No netloc: treat the path component as a unix domain socket.
        # (The original used ``for``/``else`` here; the loop has no
        # ``break`` so the ``else`` always ran — flattened for clarity.)
        if parsed.path and not parsed.netloc and not parsed.port:
            self.hosts = [parsed.path]

        self.library = options.pop('library', 'pymemcache.client')
        # fix: this previously popped the key 'library' a second time,
        # which always returned the default (the first pop removed the
        # key), so a caller-supplied 'cluster_library' option could
        # never take effect.
        self.cluster_library = options.pop(
            'cluster_library', 'pymemcache.client.hash'
        )
        self.client_getter = options.pop('client_getter', self.get_client)
        self.options = options

        if not get_dependency(self.library):
            raise ConfigurationError(
                "memcached prerequisite not available."
                " please install %s" % self.library
            )  # pragma: no cover
        self.local_storage = threading.local()
        self.local_storage.storage = None

    def get_client(self, module, hosts, **kwargs):
        """
        returns a memcached client.

        :param module: the memcached module
        :param hosts: list of memcached hosts
        :return: a ``HashClient`` when more than one host is configured,
         otherwise a single-node ``Client``
        """
        return (
            module.HashClient(hosts, **kwargs)
            if len(hosts) > 1 else module.Client(*hosts, **kwargs)
        )

    def call_memcached_func(self, func, *args, **kwargs):
        """
        calls ``func`` after dropping ``noreply`` from ``kwargs`` when the
        target client function does not accept it (e.g. the GAE memcache
        API used by the subclass).
        """
        if 'noreply' in kwargs:
            # NOTE(review): ``inspect.getargspec`` is deprecated and
            # removed in python 3.11; kept for the py2 compatibility
            # (via six) this module still targets.
            argspec = inspect.getargspec(func)
            if not ('noreply' in argspec.args or argspec.keywords):
                kwargs.pop('noreply')  # noqa
        return func(*args, **kwargs)

    @property
    def storage(self):
        """
        lazily creates a memcached client instance using a thread local
        """
        if not (
            hasattr(self.local_storage, "storage")
            and self.local_storage.storage
        ):
            self.local_storage.storage = self.client_getter(
                get_dependency(
                    self.cluster_library if len(self.hosts) > 1
                    else self.library
                ),
                self.hosts, **self.options
            )

        return self.local_storage.storage

    def get(self, key):
        """
        :param str key: the key to get the counter value for
        :return: the current counter value, ``0`` when the key is absent
        """
        return int(self.storage.get(key) or 0)

    def clear(self, key):
        """
        :param str key: the key to clear rate limits for
        """
        self.storage.delete(key)

    def incr(self, key, expiry, elastic_expiry=False):
        """
        increments the counter for a given rate limit key

        :param str key: the key to increment
        :param int expiry: amount in seconds for the key to expire in
        :param bool elastic_expiry: whether to keep extending the rate limit
         window every hit.
        :return: the new counter value
        """
        if not self.call_memcached_func(
            self.storage.add, key, 1, expiry, noreply=False
        ):
            if elastic_expiry:
                # compare-and-set loop so concurrent writers don't lose
                # increments while we also extend the window
                value, cas = self.storage.gets(key)
                retry = 0
                while (
                    not self.call_memcached_func(
                        self.storage.cas, key,
                        int(value or 0) + 1, cas, expiry
                    ) and retry < self.MAX_CAS_RETRIES
                ):
                    value, cas = self.storage.gets(key)
                    retry += 1
                self.call_memcached_func(
                    self.storage.set,
                    key + "/expires",
                    expiry + time.time(),
                    expire=expiry,
                    noreply=False
                )
                return int(value or 0) + 1
            else:
                return self.storage.incr(key, 1)
        # fresh key: record its absolute expiry alongside the counter
        self.call_memcached_func(
            self.storage.set,
            key + "/expires",
            expiry + time.time(),
            expire=expiry,
            noreply=False
        )
        return 1

    def get_expiry(self, key):
        """
        :param str key: the key to get the expiry for
        :return: the stored absolute expiry time, or the current time
         when no expiry record exists
        """
        return int(float(self.storage.get(key + "/expires") or time.time()))

    def check(self):
        """
        check if storage is healthy
        """
        try:
            self.call_memcached_func(self.storage.get, 'limiter-check')
            return True
        except:  # noqa
            return False

0 comments on commit 1a06bde

Please sign in to comment.