Skip to content

Commit

Permalink
Some code cleanup.
Browse files Browse the repository at this point in the history
Also modified the logic around encoding a bit.
  • Loading branch information
argaen committed Oct 24, 2016
1 parent 6d6e90c commit 4627e4f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 34 deletions.
10 changes: 5 additions & 5 deletions aiocache/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ def set_policy(self, class_, *args, **kwargs):
self.policy = class_(self, *args, **kwargs)

@abc.abstractmethod
async def add(self, key, value, ttl=None): # pragma: no cover
async def add(self, key, value, ttl=None, dumps_fn=None): # pragma: no cover
pass

@abc.abstractmethod
async def get(self, key, default=None): # pragma: no cover
async def get(self, key, default=None, loads_fn=None): # pragma: no cover
pass

@abc.abstractmethod
async def multi_get(self, keys): # pragma: no cover
async def multi_get(self, keys, loads_fn=None): # pragma: no cover
pass

@abc.abstractmethod
async def set(self, key, value, ttl=None): # pragma: no cover
async def set(self, key, value, ttl=None, dumps_fn=None): # pragma: no cover
pass

@abc.abstractmethod
async def multi_set(self, pairs): # pragma: no cover
async def multi_set(self, pairs, dumps_fn=None): # pragma: no cover
pass

@abc.abstractmethod
Expand Down
8 changes: 6 additions & 2 deletions aiocache/backends/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def __init__(self, *args, endpoint=None, port=None, loop=None, **kwargs):
self._loop = loop or asyncio.get_event_loop()
self.client = aiomcache.Client(self.endpoint, self.port, loop=self._loop)

@property
def _encoding(self):
return getattr(self.serializer, "encoding", "utf-8")

async def get(self, key, default=None, loads_fn=None):
"""
Get a value from the cache. Returns default if not found.
Expand All @@ -31,7 +35,7 @@ async def get(self, key, default=None, loads_fn=None):
value = await self.client.get(ns_key)

if value:
if isinstance(value, bytes):
if isinstance(value, bytes) and self._encoding:
value = bytes.decode(value)
await self.policy.post_get(key)

Expand All @@ -55,7 +59,7 @@ async def multi_get(self, keys, loads_fn=None):

decoded_values = []
for value in values:
if value is not None and isinstance(value, bytes):
if value is not None and isinstance(value, bytes) and self._encoding:
decoded_values.append(loads(bytes.decode(value)))
else:
decoded_values.append(loads(value))
Expand Down
16 changes: 8 additions & 8 deletions aiocache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,49 @@ def __init__(self, *args, endpoint=None, port=None, loop=None, **kwargs):
self._pool = None
self._loop = loop or asyncio.get_event_loop()

async def get(self, key, default=None, loads_fn=None, encoding=None):
@property
def _encoding(self):
return getattr(self.serializer, "encoding", "utf-8")

async def get(self, key, default=None, loads_fn=None):
"""
Get a value from the cache. Returns default if not found.
:param key: str
:param default: obj to return when key is not found
:param loads_fn: callable alternative to use as loads function
:param encoding: alternative encoding to use. Default is to use the self.serializer.encoding
:returns: obj deserialized
"""

loads = loads_fn or self.serializer.loads
encoding = encoding or getattr(self.serializer, "encoding", 'utf-8')
ns_key = self._build_key(key)

await self.policy.pre_get(key)

with await self._connect() as redis:
value = loads(await redis.get(ns_key, encoding=encoding))
value = loads(await redis.get(ns_key, encoding=self._encoding))

if value:
await self.policy.post_get(key)

return value or default

async def multi_get(self, keys, loads_fn=None, encoding=None):
async def multi_get(self, keys, loads_fn=None):
"""
Get a value from the cache. Returns default if not found.
:param key: str
:param loads_fn: callable alternative to use as loads function
:param encoding: alternative encoding to use. Default is to use the self.serializer.encoding
:returns: obj deserialized
"""
loads = loads_fn or self.serializer.loads
encoding = encoding or getattr(self.serializer, "encoding", 'utf-8')

for key in keys:
await self.policy.pre_get(key)

with await self._connect() as redis:
ns_keys = [self._build_key(key) for key in keys]
values = [loads(obj) for obj in await redis.mget(*ns_keys, encoding=encoding)]
values = [loads(obj) for obj in await redis.mget(*ns_keys, encoding=self._encoding)]

for key in keys:
await self.policy.post_get(key)
Expand Down
18 changes: 3 additions & 15 deletions aiocache/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
import logging


Expand All @@ -17,18 +16,7 @@
import pickle


class BaseSerializer(metaclass=abc.ABCMeta):

@abc.abstractmethod
def dumps(self, value):
pass

@abc.abstractmethod
def loads(self, value):
pass


class DefaultSerializer(BaseSerializer):
class DefaultSerializer:
"""
Dummy serializer that returns the same value passed both in serialize and
deserialize methods.
Expand All @@ -44,7 +32,7 @@ def loads(self, value):
return value


class PickleSerializer(BaseSerializer):
class PickleSerializer(DefaultSerializer):
"""
Transform data to bytes using pickle.dumps and pickle.loads to retrieve it back.
"""
Expand Down Expand Up @@ -72,7 +60,7 @@ def loads(self, value):
return pickle.loads(value)


class JsonSerializer(BaseSerializer):
class JsonSerializer(DefaultSerializer):
"""
Transform data to json string with json.dumps and json.loads to retrieve it back.
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/backends/test_base_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ async def test_set_complex_type(self, cache, obj, serializer):
(MyType(), serializers.PickleSerializer),
(MyType().__dict__, serializers.JsonSerializer),
])
async def test_get_complex_type(self, redis_cache, obj, serializer):
redis_cache.serializer = serializer()
await redis_cache.set(pytest.KEY, obj)
assert await redis_cache.get(pytest.KEY) == obj
async def test_get_complex_type(self, cache, obj, serializer):
cache.serializer = serializer()
await cache.set(pytest.KEY, obj)
assert await cache.get(pytest.KEY) == obj

@pytest.mark.asyncio
async def test_get_set_alt_serializer_functions(self, cache):
Expand Down

0 comments on commit 4627e4f

Please sign in to comment.