Skip to content

Commit

Permalink
fix memcache get_expiry
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed May 25, 2014
1 parent aa791fc commit edb5af1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
8 changes: 5 additions & 3 deletions flask_limiter/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,18 @@ def incr(self, key, expiry, elastic_expiry=False):
and retry < self.MAX_CAS_RETRIES
):
value, cas = self.storage.gets(key)
self.storage.add(key + "/expires", expiry)
retry += 1
self.storage.set(key + "/expires", expiry + time.time(), expire=expiry, noreply=False)
return int(value) + 1
else:
return self.storage.incr(key, 1)
return 1
else:
self.storage.set(key + "/expires", expiry + time.time(), expire=expiry, noreply=False)
return 1

def get_expiry(self, key):
"""
:param str key: the key to get the expiry for
"""
return int(self.storage.get(key + "/expires"))
return int(float(self.storage.get(key + "/expires") or time.time()))

36 changes: 34 additions & 2 deletions tests/test_flask_ext.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
"""
"""
import time

import logging
import unittest
from flask import Flask, Blueprint, request, current_app
import hiro
import mock
import time
from datetime import datetime
from flask.ext.limiter.extension import Limiter
import redis
import pymemcache.client


class FlaskExtTests(unittest.TestCase):
def setUp(self):
redis.Redis().flushall()
pymemcache.client.Client(('localhost', 11211)).flush_all()

def test_combined_rate_limits(self):
app = Flask(__name__)
Expand Down Expand Up @@ -409,3 +411,33 @@ def t():
resp.headers.get('X-RateLimit-Reset'),
str(int(time.time() + 1))
)

def test_headers_fixed_window_memcached(self):

app = Flask(__name__)
app.config["RATELIMIT_STRATEGY"] = "fixed-window"
app.config["RATELIMIT_STORAGE_URL"] = "memcached://localhost:11211"
limiter = Limiter(app, global_limits=["10/minute"], headers_enabled=True)

@app.route("/t1")
@limiter.limit("10/second; 20per minute")
def t():
return "test"

with app.test_client() as cli:
start = time.time()
for i in range(21):
resp = cli.get("/t1")
time.sleep(0.1)
self.assertEqual(
resp.headers.get('X-RateLimit-Limit'),
'20'
)
self.assertEqual(
resp.headers.get('X-RateLimit-Remaining'),
'0'
)
self.assertEqual(
resp.headers.get('X-RateLimit-Reset'),
str(int(start) + 60)
)

0 comments on commit edb5af1

Please sign in to comment.