Skip to content

Commit

Permalink
Merge pull request #152 from alisaifee/fix-incorrect-over-limit-acuir…
Browse files Browse the repository at this point in the history
…e_entry

Disallow acquiring > limit in moving window
  • Loading branch information
alisaifee committed Jan 16, 2023
2 parents 3234791 + 2c988c3 commit 02c9d99
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 23 deletions.
3 changes: 3 additions & 0 deletions limits/aio/storage/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ async def acquire_entry(
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False

self.events.setdefault(key, [])
await self.__schedule_expiry()
timestamp = time.time()
Expand Down
3 changes: 3 additions & 0 deletions limits/aio/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ async def acquire_entry(
"""
await self.create_indices()

if amount > limit:
return False

timestamp = time.time()
try:
updates: Dict[str, Any] = { # type: ignore
Expand Down
12 changes: 9 additions & 3 deletions limits/resources/redis/lua_scripts/acquire_moving_window.lua
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
local amount = tonumber(ARGV[4])
local entry = redis.call('lindex', KEYS[1], tonumber(ARGV[2]) - amount)
local timestamp = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local expiry = tonumber(ARGV[3])
local amount = tonumber(ARGV[4])

if amount > limit then
return false
end

local entry = redis.call('lindex', KEYS[1], limit - amount)


if entry and tonumber(entry) >= timestamp - expiry then
return false
end
local limit = tonumber(ARGV[2])
local entries= {}
for i=1, amount do
entries[i] = timestamp
Expand Down
3 changes: 3 additions & 0 deletions limits/storage/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False

self.events.setdefault(key, [])
self.__schedule_expiry()
timestamp = time.time()
Expand Down
3 changes: 3 additions & 0 deletions limits/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b
:param expiry: expiry of the entry
:param amount: the number of entries to acquire
"""
if amount > limit:
return False

timestamp = time.time()
try:
updates: Dict[str, Any] = { # type: ignore
Expand Down
23 changes: 23 additions & 0 deletions tests/aio/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,29 @@ async def test_expiry_acquire_entry(self, uri, args, expected_instance, fixture)
time.sleep(1.1)
assert await storage.get(limit.key_for()) == 0

async def test_incr_custom_amount(self, uri, args, expected_instance, fixture):
storage = storage_from_string(uri, **args)
limit = RateLimitItemPerMinute(1)
assert 1 == await storage.incr(limit.key_for(), limit.get_expiry(), amount=1)
assert 11 == await storage.incr(limit.key_for(), limit.get_expiry(), amount=10)

async def test_acquire_entry_custom_amount(
self, uri, args, expected_instance, fixture
):
if not issubclass(expected_instance, MovingWindowSupport):
pytest.skip("%s does not support acquire entry" % expected_instance)
storage = storage_from_string(uri, **args)
limit = RateLimitItemPerMinute(10)
assert not await storage.acquire_entry(
limit.key_for(), limit.amount, limit.get_expiry(), amount=11
)
assert await storage.acquire_entry(
limit.key_for(), limit.amount, limit.get_expiry(), amount=1
)
assert not await storage.acquire_entry(
limit.key_for(), limit.amount, limit.get_expiry(), amount=10
)

async def test_storage_check(self, uri, args, expected_instance, fixture):
assert await (storage_from_string(uri, **args)).check()

Expand Down
26 changes: 16 additions & 10 deletions tests/aio/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ async def test_fixed_window_multiple_cost(self, uri, args, fixture):
storage = storage_from_string(uri, **args)
limiter = FixedWindowRateLimiter(storage)
limit = RateLimitItemPerMinute(10, 2)
assert await limiter.hit(limit, cost=5)
assert (await limiter.get_window_stats(limit)).remaining == 5
assert not await limiter.hit(limit, "k1", cost=11)
assert await limiter.hit(limit, "k2", cost=5)
assert (await limiter.get_window_stats(limit, "k2")).remaining == 5
assert not await limiter.hit(limit, "k2", cost=6)

@async_all_storage
@fixed_start
Expand All @@ -79,10 +81,12 @@ async def test_fixed_window_with_elastic_expiry_multiple_cost(
storage = storage_from_string(uri, **args)
limiter = FixedWindowElasticExpiryRateLimiter(storage)
limit = RateLimitItemPerSecond(10, 2)
assert not await limiter.hit(limit, "k1", cost=11)
async with async_window(0) as (start, end):
assert await limiter.hit(limit, cost=5)
assert (await limiter.get_window_stats(limit)).remaining == 5
assert (await limiter.get_window_stats(limit)).reset_time == end + 2
assert await limiter.hit(limit, "k2", cost=5)
assert (await limiter.get_window_stats(limit, "k2")).remaining == 5
assert (await limiter.get_window_stats(limit, "k2")).reset_time == end + 2
assert not await limiter.hit(limit, "k2", cost=6)

@async_moving_window_storage
async def test_moving_window(self, uri, args, fixture):
Expand Down Expand Up @@ -121,16 +125,18 @@ async def test_moving_window_multiple_cost(self, uri, args, fixture):
limiter = MovingWindowRateLimiter(storage)
limit = RateLimitItemPerSecond(10, 2)

assert not await limiter.hit(limit, "k1", cost=11)
# 5 hits in the first 100ms
async with async_window(0.1):
assert await limiter.hit(limit, cost=5)
assert await limiter.hit(limit, "k2", cost=5)
# 5 hits in the last 100ms
async with async_window(2, delay=1.8):
assert all([await limiter.hit(limit) for i in range(5)])
assert all([await limiter.hit(limit, "k2") for i in range(5)])
# 11th fails
assert not await limiter.hit(limit)
assert all([await limiter.hit(limit) for i in range(5)])
assert (await limiter.get_window_stats(limit)).remaining == 0
assert not await limiter.hit(limit, "k2")
assert all([await limiter.hit(limit, "k2") for i in range(5)])
assert (await limiter.get_window_stats(limit, "k2")).remaining == 0
assert not await limiter.hit(limit, "k2", cost=2)

@async_moving_window_storage
async def test_moving_window_varying_cost(self, uri, args, fixture):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,27 @@ def test_expiry_acquire_entry(self, uri, args, expected_instance, fixture):
time.sleep(1.1)
assert storage.get(limit.key_for()) == 0

def test_incr_custom_amount(self, uri, args, expected_instance, fixture):
storage = storage_from_string(uri, **args)
limit = RateLimitItemPerMinute(1)
assert 1 == storage.incr(limit.key_for(), limit.get_expiry(), amount=1)
assert 11 == storage.incr(limit.key_for(), limit.get_expiry(), amount=10)

def test_acquire_entry_custom_amount(self, uri, args, expected_instance, fixture):
if not issubclass(expected_instance, MovingWindowSupport):
pytest.skip("%s does not support acquire entry" % expected_instance)
storage = storage_from_string(uri, **args)
limit = RateLimitItemPerMinute(10)
assert not storage.acquire_entry(
limit.key_for(), limit.amount, limit.get_expiry(), amount=11
)
assert storage.acquire_entry(
limit.key_for(), limit.amount, limit.get_expiry(), amount=1
)
assert not storage.acquire_entry(
limit.key_for(), limit.amount, limit.get_expiry(), amount=10
)

def test_storage_check(self, uri, args, expected_instance, fixture):
assert storage_from_string(uri, **args).check()

Expand Down
26 changes: 16 additions & 10 deletions tests/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def test_fixed_window_multiple_cost(self, uri, args, fixture):
storage = storage_from_string(uri, **args)
limiter = FixedWindowRateLimiter(storage)
limit = RateLimitItemPerMinute(10, 2)
assert limiter.hit(limit, cost=5)
assert limiter.get_window_stats(limit).remaining == 5
assert not limiter.hit(limit, "k1", cost=11)
assert limiter.hit(limit, "k2", cost=5)
assert limiter.get_window_stats(limit, "k2").remaining == 5
assert not limiter.hit(limit, "k2", cost=6)

@all_storage
@fixed_start
Expand All @@ -71,10 +73,12 @@ def test_fixed_window_with_elastic_expiry_multiple_cost(self, uri, args, fixture
storage = storage_from_string(uri, **args)
limiter = FixedWindowElasticExpiryRateLimiter(storage)
limit = RateLimitItemPerSecond(10, 2)
assert not limiter.hit(limit, "k1", cost=11)
with window(0) as (start, end):
assert limiter.hit(limit, cost=5)
assert limiter.get_window_stats(limit).remaining == 5
assert limiter.get_window_stats(limit).reset_time == end + 2
assert limiter.hit(limit, "k2", cost=5)
assert limiter.get_window_stats(limit, "k2").remaining == 5
assert limiter.get_window_stats(limit, "k2").reset_time == end + 2
assert not limiter.hit(limit, "k2", cost=6)

@moving_window_storage
def test_moving_window_empty_stats(self, uri, args, fixture):
Expand Down Expand Up @@ -105,18 +109,20 @@ def test_moving_window_multiple_cost(self, uri, args, fixture):
limiter = MovingWindowRateLimiter(storage)
limit = RateLimitItemPerSecond(10, 2)

assert not limiter.hit(limit, "k1", cost=11)
# 5 hits in the first 100ms
with window(0.1):
limiter.hit(limit, cost=5)
limiter.hit(limit, "k2", cost=5)
# 5 hits in the last 100ms
with window(2, delay=1.8):
assert all(limiter.hit(limit) for i in range(5))
assert all(limiter.hit(limit, "k2") for i in range(5))
# 11th fails
assert not limiter.hit(limit)
assert not limiter.hit(limit, "k2")

# 5 more succeed since there were only 5 in the last 2 seconds
assert all([limiter.hit(limit) for i in range(5)])
assert limiter.get_window_stats(limit)[1] == 0
assert all([limiter.hit(limit, "k2") for i in range(5)])
assert limiter.get_window_stats(limit, "k2")[1] == 0
assert not limiter.hit(limit, "k2", cost=2)

@moving_window_storage
def test_moving_window_varying_cost(self, uri, args, fixture):
Expand Down

0 comments on commit 02c9d99

Please sign in to comment.