From e716157b0bf7542d43c830cdad9aadd67fd842c9 Mon Sep 17 00:00:00 2001 From: Clemens Wolff Date: Sun, 9 Oct 2016 22:52:36 -0700 Subject: [PATCH] Remove the need for a custom __str__ method Resolves #5 --- README.md | 3 -- redis_cache/redis_cache.py | 24 ++++++++- tests/fakes.py | 17 +++++++ tests/test_redis_cache.py | 99 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 tests/fakes.py diff --git a/README.md b/README.md index 1543447..436d15f 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,3 @@ Check for any open issues, or open one yourself! All contributions are appreciat # Tests `nosetests` - -# Note -When implementing the caching decorator on methods with objects as parameters, please be sure to override the `__str__` method on that object. The default behavior for most objects in Python is to return a string such as ``. Since this is in the signature of the function, it will cache using that address in memory and will result in cache misses everytime that object is changed. This is especially apparent while caching class methods since the first paramter is always the object itself (`self`). diff --git a/redis_cache/redis_cache.py b/redis_cache/redis_cache.py index 71fa19f..7e2ec66 100644 --- a/redis_cache/redis_cache.py +++ b/redis_cache/redis_cache.py @@ -58,7 +58,7 @@ def _default_signature_generator(*args, **kwargs): :return: arg1,...argn,kwarg1=kwarg1,...kwargn=kwargn """ # Join regular arguments together with commas - parsed_args = ",".join(map(lambda x: str(x), args[1])) + parsed_args = ",".join(map(_argument_to_string, args[1])) # Join keyword arguments together with `=` and commas parsed_kwargs = ",".join( map(lambda x: '%s=%s' % (x, str(kwargs[x])), kwargs) @@ -68,3 +68,25 @@ def _default_signature_generator(*args, **kwargs): lambda x: x != '', [parsed_args, parsed_kwargs] ) return ','.join(parsed) + + +def _argument_to_string(arg): + if arg.__class__.__module__ == '__builtin__': + return str(arg) + + # for backwards-compatibility: if the object defines a custom __str__ method + # use it instead of the default approach based on the class-name + string_representation = str(arg) + if string_representation != object.__str__(arg): + return string_representation + + try: + instance_namespace = arg.__class__.__name__ + instance_state = sorted((field, _argument_to_string(value)) + for field, value in vars(arg).items()) + return '{}_{}'.format(instance_namespace, instance_state) + + except TypeError: + # some non-primitive types (e.g., defaultdict) don't have a __dict__ so + # we need to have a fall-back + return str(arg) diff --git a/tests/fakes.py b/tests/fakes.py new file mode 100644 index 0000000..c7450cb --- /dev/null +++ b/tests/fakes.py @@ -0,0 +1,17 @@ +class FakeRedisClient(object): + """ + In-memory fake implementation of RedisClient + """ + def __init__(self): + self._cache = {} + + def get(self, key): + return self._cache.get(key, '') + + def set(self, key, value, **kwargs): + self._cache[key] = value + return '' + + def setex(self, key, value, expiration): + self._cache[key] = value + return '' diff --git a/tests/test_redis_cache.py b/tests/test_redis_cache.py index d677c25..e45b35e 100644 --- a/tests/test_redis_cache.py +++ b/tests/test_redis_cache.py @@ -1,8 +1,10 @@ from redis_cache.redis_cache import \ RedisCache, RedisException, DEFAULT_EXPIRATION +from fakes import FakeRedisClient from unittest import TestCase from mock import Mock, patch from inputs import SimpleObject +import collections import pickle @@ -247,3 +249,100 @@ def __str__(self): mock_client_object.assert_called_once_with(self.address, self.port) expected_hash = str(hash('cache_this_method' + str(test_class))) mock_client.get.assert_called_once_with(expected_hash) + + def test_cache_on_class_without_str_for_function_with_simple_args(self): + """ + Tests a cache hit when invoking a method that takes a primitive type as + argument on an object that does not implement a __str__ method + """ + redis_cache = RedisCache(self.address, self.port) + redis_cache.redis_client = FakeRedisClient() + + class TestClass(object): + def __init__(self): + self.call_count = collections.defaultdict(int) + + @redis_cache.cache() + def echo(self, parameter): + self.call_count[parameter] += 1 + return parameter + + primitive_argument = 'cache hit test' + + instance1 = TestClass() + instance2 = TestClass() + instance3 = TestClass() + + value1 = instance1.echo(primitive_argument) + value2 = instance2.echo(primitive_argument) + value3 = instance3.echo(primitive_argument) + + self.assertTrue(value1 == value2 == value3 == primitive_argument) + self.assertEqual(instance1.call_count[primitive_argument], 1) + self.assertEqual(instance2.call_count[primitive_argument], 0) + self.assertEqual(instance3.call_count[primitive_argument], 0) + + def test_cache_on_class_without_str_for_function_with_complex_args(self): + """ + Tests a cache hit when invoking a method that takes a complex type as + argument on an object that does not implement a __str__ method + """ + redis_cache = RedisCache(self.address, self.port) + redis_cache.redis_client = FakeRedisClient() + + class TestClass(object): + def __init__(self): + self.some_method_call_count = collections.defaultdict(int) + + @redis_cache.cache() + def some_method(self, parameter): + self.some_method_call_count[parameter] += 1 + return parameter + + complex_argument = SimpleObject('test', 42) + + instance1 = TestClass() + instance2 = TestClass() + instance3 = TestClass() + + value1 = instance1.some_method(complex_argument) + value2 = instance2.some_method(complex_argument) + value3 = instance3.some_method(complex_argument) + + self.assertTrue(value1 == value2 == value3 == complex_argument) + self.assertEqual(instance1.some_method_call_count[complex_argument], 1) + self.assertEqual(instance2.some_method_call_count[complex_argument], 0) + self.assertEqual(instance3.some_method_call_count[complex_argument], 0) + + def test_cache_on_stateful_class_without_str(self): + """ + Tests a cache hit when invoking a method on a stateful object. + """ + redis_cache = RedisCache(self.address, self.port) + redis_cache.redis_client = FakeRedisClient() + + class TestClass(object): + def __init__(self, state): + self.state = state + + @redis_cache.cache() + def some_method(self): + return self.state + + state1 = 'some state' + state2 = SimpleObject('foo', 2) + + state1_instance1 = TestClass(state1) + state1_instance2 = TestClass(state1) + state2_instance1 = TestClass(state2) + state2_instance2 = TestClass(state2) + + value1 = state1_instance1.some_method() + value2 = state1_instance1.some_method() + value3 = state1_instance2.some_method() + value4 = state2_instance1.some_method() + value5 = state2_instance2.some_method() + value6 = state2_instance1.some_method() + + self.assertTrue(value1 == value2 == value3 == state1) + self.assertTrue(value4 == value5 == value6 == state2)