Skip to content

Commit

Permalink
Merge e716157 into df190a5
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w committed Oct 10, 2016
2 parents df190a5 + e716157 commit 92104b7
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 4 deletions.
3 changes: 0 additions & 3 deletions README.md
Expand Up @@ -49,6 +49,3 @@ Check for any open issues, or open one yourself! All contributions are appreciat

# Tests
`nosetests`

# Note
When implementing the caching decorator on methods with objects as parameters, please be sure to override the `__str__` method on that object. The default behavior for most objects in Python is to return a string such as `<Object instance at 0x12345678>`. Since this is in the signature of the function, it will cache using that address in memory and will result in cache misses everytime that object is changed. This is especially apparent while caching class methods since the first paramter is always the object itself (`self`).
24 changes: 23 additions & 1 deletion redis_cache/redis_cache.py
Expand Up @@ -58,7 +58,7 @@ def _default_signature_generator(*args, **kwargs):
:return: arg1,...argn,kwarg1=kwarg1,...kwargn=kwargn
"""
# Join regular arguments together with commas
parsed_args = ",".join(map(lambda x: str(x), args[1]))
parsed_args = ",".join(map(_argument_to_string, args[1]))
# Join keyword arguments together with `=` and commas
parsed_kwargs = ",".join(
map(lambda x: '%s=%s' % (x, str(kwargs[x])), kwargs)
Expand All @@ -68,3 +68,25 @@ def _default_signature_generator(*args, **kwargs):
lambda x: x != '', [parsed_args, parsed_kwargs]
)
return ','.join(parsed)


def _argument_to_string(arg):
if arg.__class__.__module__ == '__builtin__':
return str(arg)

# for backwards-compatibility: if the object defines a custom __str__ method
# use it instead of the default approach based on the class-name
string_representation = str(arg)
if string_representation != object.__str__(arg):
return string_representation

try:
instance_namespace = arg.__class__.__name__
instance_state = sorted((field, _argument_to_string(value))
for field, value in vars(arg).items())
return '{}_{}'.format(instance_namespace, instance_state)

except TypeError:
# some non-primitive types (e.g., defaultdict) don't have a __dict__ so
# we need to have a fall-back
return str(arg)
17 changes: 17 additions & 0 deletions tests/fakes.py
@@ -0,0 +1,17 @@
class FakeRedisClient(object):
"""
In-memory fake implementation of RedisClient
"""
def __init__(self):
self._cache = {}

def get(self, key):
return self._cache.get(key, '')

def set(self, key, value, **kwargs):
self._cache[key] = value
return ''

def setex(self, key, value, expiration):
self._cache[key] = value
return ''
99 changes: 99 additions & 0 deletions tests/test_redis_cache.py
@@ -1,8 +1,10 @@
from redis_cache.redis_cache import \
RedisCache, RedisException, DEFAULT_EXPIRATION
from fakes import FakeRedisClient
from unittest import TestCase
from mock import Mock, patch
from inputs import SimpleObject
import collections
import pickle


Expand Down Expand Up @@ -247,3 +249,100 @@ def __str__(self):
mock_client_object.assert_called_once_with(self.address, self.port)
expected_hash = str(hash('cache_this_method' + str(test_class)))
mock_client.get.assert_called_once_with(expected_hash)

def test_cache_on_class_without_str_for_function_with_simple_args(self):
"""
Tests a cache hit when invoking a method that takes a primitive type as
argument on an object that does not implement a __str__ method
"""
redis_cache = RedisCache(self.address, self.port)
redis_cache.redis_client = FakeRedisClient()

class TestClass(object):
def __init__(self):
self.call_count = collections.defaultdict(int)

@redis_cache.cache()
def echo(self, parameter):
self.call_count[parameter] += 1
return parameter

primitive_argument = 'cache hit test'

instance1 = TestClass()
instance2 = TestClass()
instance3 = TestClass()

value1 = instance1.echo(primitive_argument)
value2 = instance2.echo(primitive_argument)
value3 = instance3.echo(primitive_argument)

self.assertTrue(value1 == value2 == value3 == primitive_argument)
self.assertEqual(instance1.call_count[primitive_argument], 1)
self.assertEqual(instance2.call_count[primitive_argument], 0)
self.assertEqual(instance3.call_count[primitive_argument], 0)

def test_cache_on_class_without_str_for_function_with_complex_args(self):
"""
Tests a cache hit when invoking a method that takes a complex type as
argument on an object that does not implement a __str__ method
"""
redis_cache = RedisCache(self.address, self.port)
redis_cache.redis_client = FakeRedisClient()

class TestClass(object):
def __init__(self):
self.some_method_call_count = collections.defaultdict(int)

@redis_cache.cache()
def some_method(self, parameter):
self.some_method_call_count[parameter] += 1
return parameter

complex_argument = SimpleObject('test', 42)

instance1 = TestClass()
instance2 = TestClass()
instance3 = TestClass()

value1 = instance1.some_method(complex_argument)
value2 = instance2.some_method(complex_argument)
value3 = instance3.some_method(complex_argument)

self.assertTrue(value1 == value2 == value3 == complex_argument)
self.assertEqual(instance1.some_method_call_count[complex_argument], 1)
self.assertEqual(instance2.some_method_call_count[complex_argument], 0)
self.assertEqual(instance3.some_method_call_count[complex_argument], 0)

def test_cache_on_stateful_class_without_str(self):
"""
Tests a cache hit when invoking a method on a stateful object.
"""
redis_cache = RedisCache(self.address, self.port)
redis_cache.redis_client = FakeRedisClient()

class TestClass(object):
def __init__(self, state):
self.state = state

@redis_cache.cache()
def some_method(self):
return self.state

state1 = 'some state'
state2 = SimpleObject('foo', 2)

state1_instance1 = TestClass(state1)
state1_instance2 = TestClass(state1)
state2_instance1 = TestClass(state2)
state2_instance2 = TestClass(state2)

value1 = state1_instance1.some_method()
value2 = state1_instance1.some_method()
value3 = state1_instance2.some_method()
value4 = state2_instance1.some_method()
value5 = state2_instance2.some_method()
value6 = state2_instance1.some_method()

self.assertTrue(value1 == value2 == value3 == state1)
self.assertTrue(value4 == value5 == value6 == state2)

0 comments on commit 92104b7

Please sign in to comment.