diff --git a/aioredis/connection.py b/aioredis/connection.py index c59fe8c2c..d641d024a 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -150,14 +150,14 @@ def _process_data(self, obj): def _process_pubsub(self, obj): """Processes pubsub messages.""" - # TODO: decode data kind, *pattern, chan, data = obj if self._in_pubsub and self._waiters: self._process_data(obj) if kind in (b'subscribe', b'unsubscribe'): if kind == b'subscribe' and chan not in self._pubsub_channels: - self._pubsub_channels[chan] = Channel(chan, loop=self._loop) + self._pubsub_channels[chan] = Channel(chan, is_pattern=False, + loop=self._loop) elif kind == b'unsubscribe': ch = self._pubsub_channels.pop(chan, None) if ch: @@ -165,7 +165,8 @@ def _process_pubsub(self, obj): self._in_pubsub = data elif kind in (b'psubscribe', b'punsubscribe'): if kind == b'psubscribe' and chan not in self._pubsub_patterns: - self._pubsub_patterns[chan] = Channel(chan, loop=self._loop) + self._pubsub_patterns[chan] = Channel(chan, is_pattern=True, + loop=self._loop) elif kind == b'punsubscribe': ch = self._pubsub_patterns.pop(chan, None) if ch: diff --git a/aioredis/util.py b/aioredis/util.py index e2e11d4bc..2baee6fb2 100644 --- a/aioredis/util.py +++ b/aioredis/util.py @@ -44,11 +44,12 @@ class Channel: """Wrapper around asyncio.Queue.""" __slots__ = ('_queue', '_name', '_closed', '_waiter', - '_loop') + '_is_pattern', '_loop') - def __init__(self, name, loop=None): + def __init__(self, name, is_pattern, loop=None): self._queue = asyncio.Queue(loop=loop) self._name = name + self._is_pattern = is_pattern self._loop = loop self._closed = False self._waiter = None @@ -58,6 +59,11 @@ def name(self): """Encoded channel name/pattern.""" return self._name + @property + def is_pattern(self): + """Set to True if channel is subscribed to pattern.""" + return self._is_pattern + @asyncio.coroutine def get(self, *, encoding=None, decoder=None): """Coroutine that waits for and returns a message. @@ -68,10 +74,14 @@ def get(self, *, encoding=None, decoder=None): if self._closed: pass # raise error msg = yield from self._queue.get() + if self._is_pattern: + dest_channel, msg = msg if encoding is not None: msg = msg.decode(encoding) if decoder is not None: msg = decoder(msg) + if self._is_pattern: + return dest_channel, msg return msg @asyncio.coroutine diff --git a/tests/connection_test.py b/tests/connection_test.py index 7c4918804..6c9c6ca56 100644 --- a/tests/connection_test.py +++ b/tests/connection_test.py @@ -230,24 +230,28 @@ def test_pubsub_messages(self): self.assertEqual(res, [b'subscribe', b'chan:1', 1]) self.assertIn(b'chan:1', sub.pubsub_channels) - queue = sub.pubsub_channels[b'chan:1'] + chan = sub.pubsub_channels[b'chan:1'] + self.assertEqual(chan.name, b'chan:1') + self.assertTrue(chan.is_active()) res = yield from pub.execute('publish', 'chan:1', 'Hello!') self.assertEqual(res, 1) - msg = yield from queue.get() + msg = yield from chan.get() self.assertEqual(msg, b'Hello!') res = yield from sub.execute('psubscribe', 'chan:*') self.assertEqual(res, [b'psubscribe', b'chan:*', 2]) self.assertIn(b'chan:*', sub.pubsub_patterns) - queue2 = sub.pubsub_patterns[b'chan:*'] + chan2 = sub.pubsub_patterns[b'chan:*'] + self.assertEqual(chan2.name, b'chan:*') + self.assertTrue(chan2.is_active()) res = yield from pub.execute('publish', 'chan:1', 'Hello!') self.assertEqual(res, 2) - msg = yield from queue.get() + msg = yield from chan.get() self.assertEqual(msg, b'Hello!') - dest_chan, msg = yield from queue2.get() + dest_chan, msg = yield from chan2.get() self.assertEqual(dest_chan, b'chan:1') self.assertEqual(msg, b'Hello!')