## References

- https://medium.com/@sahiljadon/rate-limiting-using-redis-lists-and-sorted-sets-9b42bc192222
- https://redislabs.com/redis-best-practices/basic-rate-limiting/
- https://www.binpress.com/rate-limiting-with-redis-1/
- https://blog.callr.tech/rate-limiting-for-distributed-systems-with-redis-and-lua/
- https://brandur.org/redis-cluster
- https://engineering.classdojo.com/blog/2015/02/06/rolling-rate-limiter/

In [1]:
import time

import redis

In [2]:
r = redis.Redis(decode_responses=True)

In [3]:
r.ping()

True

In [4]:
r.set("hello", "world")

True

In [6]:
# Basic lua script.
lua = "return redis.call('GET', KEYS[1])"
r.eval(lua, 1, "hello")

'world'

In [7]:
# Alternative with register_script method.
lua = """
local value = redis.call('GET', KEYS[1])
value = tonumber(value)
return value * ARGV[1]"""
multiply = r.register_script(lua)

r.set("foo", 2)
multiply(keys=["foo"], args=[2])

4

## Allow only N API requests per minute

In [17]:
class RateLimiter:
    def __init__(self, conn, limit=5):
        self.conn = conn
        self.nkeys = 1
        self.limit = limit
        self.script = """
            if redis.call('EXISTS', KEYS[1]) == 0 then
                redis.call('SETEX', KEYS[1], 60, 0)
            end
            redis.call('INCR', KEYS[1])
            if tonumber(redis.call('GET', KEYS[1])) <= tonumber(ARGV[1]) then
                return 'ok'
            else
                return 'limit exceeded'
            end
        """

    def allow(self, ip):
        return self.conn.eval(self.script, self.nkeys, ip, self.limit) == "ok"

In [18]:
ratelimit = RateLimiter(r)

In [24]:
ratelimit.allow(1)

False

In [55]:
r.get("1")

'1'

## Allow only N API requests per minute on a running window

In the previous redis version, passing `time` to the script is always the recommended approach, as mentioned [here](https://redis.io/docs/manual/programmability/eval-intro/#:~:text=Acts%20such%20as%20using%20the%20system%20time%2C%20calling%20Redis%20commands%20that%20return%20random%20values%20(e.g.%2C%20RANDOMKEY)%2C%20or%20using%20Lua%27s%20random%20number%20generator%2C%20could%20result%20in%20scripts%20that%20will%20not%20evaluate%20consistently.). Call to `redis.call("TIME")` is not recommended.

However, in the newer version, you can do that.

In [55]:
# Returns seconds
int(time.time())

1673341712

In [62]:
class RateLimiter:
    def __init__(self, conn, n=5):
        self.conn = conn
        self.n = n
        self.lua = self.conn.register_script(
            f"""
            -- ARGV[1]: The limit.
            -- ARGV[2]: The current timestamp in seconds.
            -- KEYS[1]: The key to rate limit, e.g. clientIP + userID/sessionID
            local count = tonumber(redis.call('LLEN', KEYS[1]))
            if count < tonumber(ARGV[1]) then
                redis.call('LPUSH', KEYS[1], now)
                return 'ok'
            else
                local now = redis.call('TIME')[1]
                local time = tonumber(redis.call('LINDEX', KEYS[1], -1))
                if now - time < 60 then
                    return 'limit exceeded'
                else
                    -- Push the timestamp to the list.
                    redis.call('LPUSH', KEYS[1], now)
                    
                    -- Remove previous item in the list.
                    redis.call('RPOP', KEYS[1])
                    return 'ok'
                end
            end
        """
        )

    def allow(self, ip):
        return self.lua(keys=[ip], args=[self.n]) == "ok"

In [63]:
ratelimit = RateLimiter(r, 10)

In [75]:
ratelimit.allow(1)

False

In [76]:
r.llen("1")

10

## Rate limiting using sorted set

In [81]:
class RateLimiter:
    def __init__(self, conn, n=5):
        self.conn = conn
        self.n = n
        self.lua = self.conn.register_script(
            f"""
            -- ARGV[1]: The current limit.
            -- KEYS[1]: The key to rate limit, e.g. clientIP + userID/sessionID.
            local limit = tonumber(ARGV[1])
            local now = redis.call('TIME')
            -- The first argument is seconds, the second is microseconds.
            -- Convert them to microseconds.
            local now_ms = math.floor(now[1] * 1000 + now[2] / 1000)
            
            -- Delete all keys that are older than 1 minute ago.
            redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, now_ms - 60*1000)
            
            -- Find the number of remaining tokens left. 
            if tonumber(redis.call('ZCARD', KEYS[1])) < limit then
                redis.call('ZADD', KEYS[1], now_ms, now_ms)
                return 'ok'
            else
                return 'limit exceeded'
            end
        """
        )

    def allow(self, ip):
        # We need millisecond precisions - else the seconds will be counted as 1 item in the sorted set.
        return self.lua(keys=[ip], args=[self.n]) == "ok"

In [82]:
ratelimit = RateLimiter(r)
for i in range(6):
    print(ratelimit.allow("0.0.0.0"))

True
True
True
True
True
False


In [83]:
r.zrange("0.0.0.0", 0, -1, withscores=True)

[('1673342160207', 1673342160207.0),
 ('1673342160221', 1673342160221.0),
 ('1673342160223', 1673342160223.0),
 ('1673342160236', 1673342160236.0),
 ('1673342160247', 1673342160247.0)]