Skip to content

Commit

Permalink
Fix #54: do deeper lookup for inhereted classes
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Jul 14, 2015
1 parent 4f8eb43 commit 18708ab
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ CHANGES

* Implement monitoring ZMQ events #50

* Do deeper lookup for inhereted classes #54

0.6.1 (2015-05-19)
^^^^^^^^^^^^^^^^^^

Expand Down
8 changes: 7 additions & 1 deletion aiozmq/rpc/packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def __init__(self, *, translation_table=None):
translation_table = ChainMap(translation_table, _default)
self.translation_table = translation_table
self._pack_cache = {}
self._unpack_cache = {}
for code in sorted(self.translation_table):
cls, packer, unpacker = self.translation_table[code]
self._pack_cache[cls] = (code, packer)
self._unpack_cache[code] = unpacker

def packb(self, data):
return packb(data, encoding='utf-8', use_bin_type=True,
Expand All @@ -46,6 +51,7 @@ def ext_type_pack_hook(self, obj, _sentinel=object()):
cls, packer, unpacker = self.translation_table[code]
if isinstance(obj, cls):
self._pack_cache[obj_class] = (code, packer)
self._unpack_cache[code] = unpacker
return ExtType(code, packer(obj))
else:
self._pack_cache[obj_class] = None
Expand All @@ -57,7 +63,7 @@ def ext_type_pack_hook(self, obj, _sentinel=object()):

def ext_type_unpack_hook(self, code, data):
try:
cls, packer, unpacker = self.translation_table[code]
unpacker = self._unpack_cache[code]
return unpacker(data)
except KeyError:
return ExtType(code, data)
21 changes: 21 additions & 0 deletions tests/rpc_packer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,24 @@ def test_override_translators(self):
self.assertEqual(ExtType(125, data), packer.ext_type_pack_hook(pt))
with self.assertRaisesRegex(TypeError, "Unknown type: "):
packer.ext_type_pack_hook(dt)

def test_preserve_resolution_order(self):
class A:
pass

class B(A):
pass

dump_a = mock.Mock(return_value=b'a')
load_a = mock.Mock(return_value=A())

dump_b = mock.Mock(return_value=b'b')
load_b = mock.Mock(return_value=B())

translation_table = {
1: (A, dump_a, load_a),
2: (B, dump_b, load_b),
}
packer = _Packer(translation_table=translation_table)
self.assertEqual(packer.packb(ExtType(1, b'a')), packer.packb(A()))
self.assertEqual(packer.packb(ExtType(2, b'b')), packer.packb(B()))

0 comments on commit 18708ab

Please sign in to comment.