## References

- https://medium.com/@sahiljadon/rate-limiting-using-redis-lists-and-sorted-sets-9b42bc192222
- https://redislabs.com/redis-best-practices/basic-rate-limiting/
- https://www.binpress.com/rate-limiting-with-redis-1/
- https://blog.callr.tech/rate-limiting-for-distributed-systems-with-redis-and-lua/
- https://brandur.org/redis-cluster
- https://engineering.classdojo.com/blog/2015/02/06/rolling-rate-limiter/

In [1]:
import redis
import time

In [41]:
r = redis.Redis(password=123456, decode_responses=True)

In [42]:
r.ping()

True

In [43]:
r.set('hello', 'world')

True

In [44]:
# Basic lua script.
lua = "return redis.call('GET', KEYS[1])"
r.eval(lua, 1, 'hello')

'world'

## Allow only N API requests per minute

In [45]:
class RateLimiter:
    def __init__(self, conn):
        self.conn = conn
        self.script = '''
            if redis.call('EXISTS', KEYS[1]) == 0 then
                redis.call('SETEX', KEYS[1], 60, 0)
            end
            redis.call('INCR', KEYS[1])
            if tonumber(redis.call('GET', KEYS[1])) <= 5 then
                return 'ok'
            else
                return 'limit exceeded'
            end
        '''
    
    def allow(self, ip):
        return self.conn.eval(self.script, 1, ip) == 'ok'

In [46]:
ratelimit = RateLimiter(r)

In [54]:
ratelimit.allow(1)

True

In [55]:
r.get('1')

'1'

## Allow only N API requests per minute on a running window

In [117]:
class RateLimiter:
    def __init__(self, conn, n=5):
        self.conn = conn
        self.lua = self.conn.register_script(f'''
            -- ARGV[1]: The current timestamp in seconds.
            -- KEYS[1]: The key to rate limit, e.g. clientIP + userID/sessionID
            local count = tonumber(redis.call('LLEN', KEYS[1]))
            if count < {n} then
                redis.call('LPUSH', KEYS[1], ARGV[1])
                return 'ok'
            else
                local time = tonumber(redis.call('LINDEX', KEYS[1], -1))
                if ARGV[1] - time < 60 then
                    return 'limit exceeded'
                else
                    -- Push the timestamp to the list.
                    redis.call('LPUSH', KEYS[1], ARGV[1])
                    
                    -- Remove previous item in the list.
                    redis.call('RPOP', KEYS[1])
                    return 'ok'
                end
            end
        ''')
    
    def allow(self, ip):
        return self.lua(keys=[ip], 
                        args=[int(time.time())]) == 'ok'

In [127]:
ratelimit = RateLimiter(r, 10)

In [130]:
ratelimit.allow(1)

True

In [131]:
r.llen('1')

7

## Rate limiting using sorted set

In [178]:
class RateLimiter:
    def __init__(self, conn, n=5):
        self.conn = conn
        self.lua = self.conn.register_script(f'''
            -- ARGV[1]: The current timestamp in seconds.
            -- KEYS[1]: The key to rate limit, e.g. clientIP + userID/sessionID.
            
            -- Delete all keys that are older than 1 minute ago.
            redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, ARGV[1] - 60*1000)
            
            -- Find the number of remaining tokens left. 
            if tonumber(redis.call('ZCARD', KEYS[1])) < {n} then
                redis.call('ZADD', KEYS[1], ARGV[1], ARGV[1])
                return 'ok'
            else
                return 'limit exceeded'
            end
        ''')
        
    def allow(self, ip):
        # We need millisecond precisions - else the seconds will be counted as 1 item in the sorted set.
        return self.lua(keys=[ip], 
                        args=[int(time.time() * 1000)]) == 'ok'

In [209]:
ratelimit = RateLimiter(r)
for i in range(6):
    print(ratelimit.allow('0.0.0.0'))

True
True
True
True
True
False


In [210]:
r.zrange('0.0.0.0', 0, -1, withscores=True)

[('1566309942739', 1566309942739.0),
 ('1566309942784', 1566309942784.0),
 ('1566309942815', 1566309942815.0),
 ('1566309942845', 1566309942845.0),
 ('1566309942862', 1566309942862.0)]