Skip to content

Commit

Permalink
add redis_pool param, run py 3.6 tests, use scan, and create RateLimi…
Browse files Browse the repository at this point in the history
…ter class

* Allow to configure redis pool.
* Run tests on python 3.6
* Use scan instead of keys.
* Add RateLimiter.
* Update README.rst
  • Loading branch information
kramarz authored and victor-torres committed Nov 30, 2017
1 parent 90f45fa commit 2e0ede6
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 4 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Expand Up @@ -2,6 +2,7 @@ language: python
python:
- "2.7"
- "3.5"
- "3.6"
script: "make test"
branches:
only:
Expand Down
14 changes: 14 additions & 0 deletions README.rst
Expand Up @@ -46,3 +46,17 @@ Example: 100 requests per hour
>>> except TooManyRequests:
>>> return '429 Too Many Requests'
>>>
You can also setup factory to use later.
Example:

.. code-block:: python
>>> from redis_rate_limit import RateLimiter, TooManyRequests
>>> limiter = RateLimiter(resource='users_list', max_requests=100, expire=3600)
>>> try:
>>> with limiter.limit(client='192.168.0.10'):
>>> return '200 OK'
>>> except TooManyRequests:
>>> return '429 Too Many Requests'
>>>
36 changes: 33 additions & 3 deletions redis_rate_limit/__init__.py
Expand Up @@ -42,7 +42,7 @@ class RateLimit(object):
This class offers an abstraction of a Rate Limit algorithm implemented on
top of Redis >= 2.6.0.
"""
def __init__(self, resource, client, max_requests, expire=None):
def __init__(self, resource, client, max_requests, expire=None, redis_pool=REDIS_POOL):
"""
Class initialization method checks if the Rate Limit algorithm is
actually supported by the installed Redis version and sets some
Expand All @@ -54,8 +54,10 @@ def __init__(self, resource, client, max_requests, expire=None):
:param client: client identifier string (i.e. ‘192.168.0.10’)
:param max_requests: integer (i.e. ‘10’)
:param expire: seconds to wait before resetting counters (i.e. ‘60’)
:param redis_pool: instance of redis.ConnectionPool.
Default: ConnectionPool(host='127.0.0.1', port=6379, db=0)
"""
self._redis = Redis(connection_pool=REDIS_POOL)
self._redis = Redis(connection_pool=redis_pool)
if not self._is_rate_limit_supported():
raise RedisVersionNotSupported()

Expand Down Expand Up @@ -122,6 +124,34 @@ def _reset(self):
"""
Deletes all keys that start with ‘rate_limit:’.
"""
for rate_limit_key in self._redis.keys('rate_limit:*'):
matching_keys = self._redis.scan_iter(match='{0}*'.format('rate_limit:*'))
for rate_limit_key in matching_keys:
self._redis.delete(rate_limit_key)


class RateLimiter(object):
def __init__(self, resource, max_requests, expire=None, redis_pool=REDIS_POOL):
"""
Rate limit factory. Checks if RateLimit is supported when limit is called.
:param resource: resource identifier string (i.e. ‘user_pictures’)
:param max_requests: integer (i.e. ‘10’)
:param expire: seconds to wait before resetting counters (i.e. ‘60’)
:param redis_pool: instance of redis.ConnectionPool.
Default: ConnectionPool(host='127.0.0.1', port=6379, db=0)
"""
self.resource = resource
self.max_requests = max_requests
self.expire = expire
self.redis_pool = redis_pool

def limit(self, client):
"""
:param client: client identifier string (i.e. ‘192.168.0.10’)
"""
return RateLimit(
resource=self.resource,
client=client,
max_requests=self.max_requests,
expire=self.expire,
redis_pool=self.redis_pool,
)
23 changes: 22 additions & 1 deletion tests/rate_limit_test.py
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
import unittest
import time
from redis_rate_limit import RateLimit, TooManyRequests
from redis_rate_limit import RateLimit, RateLimiter, TooManyRequests


class TestRedisRateLimit(unittest.TestCase):
Expand Down Expand Up @@ -64,6 +64,27 @@ def test_not_expired(self):
with self.rate_limit:
pass

def test_limit_10_using_rate_limiter(self):
"""
Should raise TooManyRequests Exception when trying to increment for the
eleventh time.
"""
self.rate_limit = RateLimiter(resource='test', max_requests=10,
expire=2).limit(client='localhost')
self.assertEqual(self.rate_limit.get_usage(), 0)
self.assertEqual(self.rate_limit.has_been_reached(), False)

self._make_10_requests()
self.assertEqual(self.rate_limit.get_usage(), 10)
self.assertEqual(self.rate_limit.has_been_reached(), True)

with self.assertRaises(TooManyRequests):
with self.rate_limit:
pass

self.assertEqual(self.rate_limit.get_usage(), 11)
self.assertEqual(self.rate_limit.has_been_reached(), True)


if __name__ == '__main__':
unittest.main()
Expand Down

0 comments on commit 2e0ede6

Please sign in to comment.