Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement gets cas commands #33

Merged
merged 4 commits into from
Jan 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
93 changes: 88 additions & 5 deletions aiomcache/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def close(self):
@asyncio.coroutine
def _multi_get(self, conn, *keys):
# req - get <key> [<key> ...]\r\n
# resp - VALUE <key> <flags> <bytes> [<cas unique>]\r\n
# resp - VALUE <key> <flags> <bytes>\r\n
# <data block>\r\n (if exists)
# [...]
# END\r\n
Expand Down Expand Up @@ -118,6 +118,51 @@ def _multi_get(self, conn, *keys):
raise ClientException('received too many responses')
return [received.get(k, None) for k in keys]

@asyncio.coroutine
def _multi_gets(self, conn, *keys):
# req - get <key> [<key> ...]\r\n
# resp - VALUE <key> <flags> <bytes> <cas unique>\r\n
# <data block>\r\n (if exists)
# [...]
# END\r\n
if not keys:
return []

[self._validate_key(key) for key in keys]
if len(set(keys)) != len(keys):
raise ClientException('duplicate keys passed to multi_get')

conn.writer.write(b'gets ' + b' '.join(keys) + b'\r\n')

received = {}
line = yield from conn.reader.readline()

while line != b'END\r\n':
terms = line.split()

if len(terms) == 5 and terms[0] == b'VALUE': # exists
key = terms[1]
flags = int(terms[2])
length = int(terms[3])
cas = int(terms[4])

if flags != 0:
raise ClientException('received non zero flags')

val = (yield from conn.reader.readexactly(length+2))[:-2]
if key in received:
raise ClientException('duplicate results from server')

received[key] = val, cas
else:
raise ClientException('get failed', line)

line = yield from conn.reader.readline()

if len(received) > len(keys):
raise ClientException('received too many responses')
return [received.get(k, [None, None]) for k in keys]

@acquire
def delete(self, conn, key):
"""Deletes a key/value pair from the server.
Expand Down Expand Up @@ -146,6 +191,19 @@ def get(self, conn, key, default=None):
result = yield from self._multi_get(conn, key)
return (result[0] or default) if result else default

@acquire
def gets(self, conn, key, default=None):
"""Gets a single value from the server.

:param key: ``bytes``, is the key for the item being fetched
:param default: default value if there is no value.
:return: ``bytes``, is the data for this specified key.
"""
result = yield from self._multi_gets(conn, key)
if result and result[0][0] is not None:
return result[0]
return default, None

@acquire
def multi_get(self, conn, *keys):
"""Takes a list of keys and returns a list of values.
Expand Down Expand Up @@ -189,10 +247,13 @@ def stats(self, conn, args=None):

@asyncio.coroutine
def _storage_command(self, conn, command, key, value,
flags=0, exptime=0):
flags=0, exptime=0, cas=None):
# req - set <key> <flags> <exptime> <bytes> [noreply]\r\n
# <data block>\r\n
# resp - STORED\r\n (or others)
# req - set <key> <flags> <exptime> <bytes> <cas> [noreply]\r\n
# <data block>\r\n
# resp - STORED\r\n (or others)

# typically, if val is > 1024**2 bytes server returns:
# SERVER_ERROR object too large for cache\r\n
Expand All @@ -206,11 +267,14 @@ def _storage_command(self, conn, command, key, value,
raise ValidationException('exptime negative', exptime)

args = [str(a).encode('utf-8') for a in (flags, exptime, len(value))]
_cmd = b' '.join([command, key] + args) + b'\r\n'
cmd = _cmd + value + b'\r\n'
_cmd = b' '.join([command, key] + args)
if cas:
_cmd += b' ' + str(cas).encode('utf-8')
cmd = _cmd + b'\r\n' + value + b'\r\n'
resp = yield from self._execute_simple_command(conn, cmd)

if resp not in (const.STORED, const.NOT_STORED):
if resp not in (
const.STORED, const.NOT_STORED, const.EXISTS, const.NOT_FOUND):
raise ClientException('stats {} failed'.format(command), resp)
return resp == const.STORED

Expand All @@ -230,6 +294,25 @@ def set(self, conn, key, value, exptime=0):
conn, b'set', key, value, flags, exptime)
return resp

@acquire
def cas(self, conn, key, value, cas_token, exptime=0):
"""Sets a key to a value on the server
with an optional exptime (0 means don't auto-expire)
only if value hasn't change from first retrieval

:param key: ``bytes``, is the key of the item.
:param value: ``bytes``, data to store.
:param exptime: ``int``, is expiration time. If it's 0, the
item never expires.
:param cas_token: ``int``, unique cas token retrieve from previous
``gets``
:return: ``bool``, True in case of success.
"""
flags = 0 # TODO: fix when exception removed
resp = yield from self._storage_command(
conn, b'cas', key, value, flags, exptime, cas=cas_token)
return resp

@acquire
def add(self, conn, key, value, exptime=0):
"""Store this data, but only if the server *doesn't* already
Expand Down
1 change: 1 addition & 0 deletions aiomcache/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
NOT_FOUND = b'NOT_FOUND'
DELETED = b'DELETED'
VERSION = b'VERSION'
EXISTS = b'EXISTS'
OK = b'OK'
41 changes: 40 additions & 1 deletion tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_set_get(mcache, loop):
test_value = yield from mcache.get(b'not:' + key)
assert test_value is None
test_value = yield from mcache.get(b'not:' + key, default=value)
assert test_value is value
assert test_value == value

with mock.patch.object(mcache, '_execute_simple_command') as patched:
fut = asyncio.Future(loop=loop)
Expand All @@ -59,6 +59,24 @@ def test_set_get(mcache, loop):
yield from mcache.set(key, value)


@pytest.mark.run_loop
def test_gets(mcache, loop):
key, value = b'key:set', b'1'
yield from mcache.set(key, value)

test_value, cas = yield from mcache.gets(key)
assert test_value == value
assert isinstance(cas, int)

test_value, cas = yield from mcache.gets(b'not:' + key)
assert test_value is None
assert cas is None

test_value, cas = yield from mcache.gets(b'not:' + key, default=value)
assert test_value == value
assert cas is None


@pytest.mark.run_loop
def test_multi_get(mcache):
key1, value1 = b'key:multi_get:1', b'1'
Expand Down Expand Up @@ -108,6 +126,27 @@ def test_set_errors(mcache):
yield from mcache.set(key, value, exptime=3.14)


@pytest.mark.run_loop
def test_gets_cas(mcache, loop):
key, value = b'key:set', b'1'
yield from mcache.set(key, value)

test_value, cas = yield from mcache.gets(key)

stored = yield from mcache.cas(key, value, cas)
assert stored is True

stored = yield from mcache.cas(key, value, cas)
assert stored is False


@pytest.mark.run_loop
def test_cas_missing(mcache, loop):
key, value = b'key:set', b'1'
stored = yield from mcache.cas(key, value, 123)
assert stored is False


@pytest.mark.run_loop
def test_add(mcache):
key, value = b'key:add', b'1'
Expand Down