Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Commit

Permalink
few more bits
Browse files Browse the repository at this point in the history
  • Loading branch information
popravich committed May 25, 2015
1 parent 90118c8 commit d561755
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
7 changes: 4 additions & 3 deletions aioredis/connection.py
Expand Up @@ -150,22 +150,23 @@ def _process_data(self, obj):

def _process_pubsub(self, obj):
"""Processes pubsub messages."""
# TODO: decode data
kind, *pattern, chan, data = obj
if self._in_pubsub and self._waiters:
self._process_data(obj)

if kind in (b'subscribe', b'unsubscribe'):
if kind == b'subscribe' and chan not in self._pubsub_channels:
self._pubsub_channels[chan] = Channel(chan, loop=self._loop)
self._pubsub_channels[chan] = Channel(chan, is_pattern=False,
loop=self._loop)
elif kind == b'unsubscribe':
ch = self._pubsub_channels.pop(chan, None)
if ch:
ch.close()
self._in_pubsub = data
elif kind in (b'psubscribe', b'punsubscribe'):
if kind == b'psubscribe' and chan not in self._pubsub_patterns:
self._pubsub_patterns[chan] = Channel(chan, loop=self._loop)
self._pubsub_patterns[chan] = Channel(chan, is_pattern=True,
loop=self._loop)
elif kind == b'punsubscribe':
ch = self._pubsub_patterns.pop(chan, None)
if ch:
Expand Down
14 changes: 12 additions & 2 deletions aioredis/util.py
Expand Up @@ -44,11 +44,12 @@ class Channel:
"""Wrapper around asyncio.Queue."""
__slots__ = ('_queue', '_name',
'_closed', '_waiter',
'_loop')
'_is_pattern', '_loop')

def __init__(self, name, loop=None):
def __init__(self, name, is_pattern, loop=None):
self._queue = asyncio.Queue(loop=loop)
self._name = name
self._is_pattern = is_pattern
self._loop = loop
self._closed = False
self._waiter = None
Expand All @@ -58,6 +59,11 @@ def name(self):
"""Encoded channel name/pattern."""
return self._name

@property
def is_pattern(self):
"""Set to True if channel is subscribed to pattern."""
return self._is_pattern

@asyncio.coroutine
def get(self, *, encoding=None, decoder=None):
"""Coroutine that waits for and returns a message.
Expand All @@ -68,10 +74,14 @@ def get(self, *, encoding=None, decoder=None):
if self._closed:
pass # raise error
msg = yield from self._queue.get()
if self._is_pattern:
dest_channel, msg = msg
if encoding is not None:
msg = msg.decode(encoding)
if decoder is not None:
msg = decoder(msg)
if self._is_pattern:
return dest_channel, msg
return msg

@asyncio.coroutine
Expand Down
14 changes: 9 additions & 5 deletions tests/connection_test.py
Expand Up @@ -230,24 +230,28 @@ def test_pubsub_messages(self):
self.assertEqual(res, [b'subscribe', b'chan:1', 1])

self.assertIn(b'chan:1', sub.pubsub_channels)
queue = sub.pubsub_channels[b'chan:1']
chan = sub.pubsub_channels[b'chan:1']
self.assertEqual(chan.name, b'chan:1')
self.assertTrue(chan.is_active())

res = yield from pub.execute('publish', 'chan:1', 'Hello!')
self.assertEqual(res, 1)
msg = yield from queue.get()
msg = yield from chan.get()
self.assertEqual(msg, b'Hello!')

res = yield from sub.execute('psubscribe', 'chan:*')
self.assertEqual(res, [b'psubscribe', b'chan:*', 2])
self.assertIn(b'chan:*', sub.pubsub_patterns)
queue2 = sub.pubsub_patterns[b'chan:*']
chan2 = sub.pubsub_patterns[b'chan:*']
self.assertEqual(chan2.name, b'chan:*')
self.assertTrue(chan2.is_active())

res = yield from pub.execute('publish', 'chan:1', 'Hello!')
self.assertEqual(res, 2)

msg = yield from queue.get()
msg = yield from chan.get()
self.assertEqual(msg, b'Hello!')
dest_chan, msg = yield from queue2.get()
dest_chan, msg = yield from chan2.get()
self.assertEqual(dest_chan, b'chan:1')
self.assertEqual(msg, b'Hello!')

Expand Down

0 comments on commit d561755

Please sign in to comment.