Skip to content

Commit

Permalink
Merge pull request #146 from aio-libs/testutils
Browse files Browse the repository at this point in the history
Remove usage of asyncio.test_utils
  • Loading branch information
JelleZijlstra committed Feb 5, 2019
2 parents aa8e0bc + 7570773 commit 4e6703c
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 15 deletions.
8 changes: 6 additions & 2 deletions tests/rpc_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging

from unittest import mock
from asyncio.test_utils import run_briefly
from aiozmq._test_util import log_hook, RpcMixin


Expand Down Expand Up @@ -185,8 +184,13 @@ def communicate():
self.assertEqual(('suspicious',), ret.args)
self.assertIsNone(ret.exc_info)

@asyncio.coroutine
def dummy():
if False:
yield

self.loop.run_until_complete(communicate())
run_briefly(self.loop)
self.loop.run_until_complete(dummy())

def test_call_closed_pipeline(self):
client, server = self.make_pipeline_pair()
Expand Down
168 changes: 155 additions & 13 deletions tests/transport_test.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,168 @@
import unittest

import asyncio
import collections
import zmq
import aiozmq
import errno
import selectors
import weakref

from collections import deque

from asyncio import test_utils
from aiozmq.core import _ZmqTransportImpl, _ZmqLooplessTransportImpl
from unittest import mock

from aiozmq._test_util import check_errno


@asyncio.coroutine
def dummy():
if False:
yield


# make_test_protocol, TestSelector, and TestLoop were taken from
# test.test_asyncio.utils in CPython.
# https://github.com/python/cpython/blob/9602643120a509858d0bee4215d7f150e6125468/Lib/test/test_asyncio/utils.py

def make_test_protocol(base):
dct = {}
for name in dir(base):
if name.startswith('__') and name.endswith('__'):
# skip magic names
continue
dct[name] = mock.Mock(return_value=None)
return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):

def __init__(self):
self.keys = {}

def register(self, fileobj, events, data=None):
key = selectors.SelectorKey(fileobj, 0, events, data)
self.keys[fileobj] = key
return key

def unregister(self, fileobj):
return self.keys.pop(fileobj)

def select(self, timeout):
return []

def get_map(self):
return self.keys


class TestLoop(asyncio.base_events.BaseEventLoop):
def __init__(self):
super().__init__()

self._selector = TestSelector()

self.readers = {}
self.writers = {}
self.reset_counters()

self._transports = weakref.WeakValueDictionary()

def _add_reader(self, fd, callback, *args):
self.readers[fd] = asyncio.events.Handle(callback, args, self)

def _remove_reader(self, fd):
self.remove_reader_count[fd] += 1
if fd in self.readers:
del self.readers[fd]
return True
else:
return False

def assert_reader(self, fd, callback, *args):
if fd not in self.readers:
raise AssertionError('fd {fd} is not registered'.format(fd=fd))
handle = self.readers[fd]
if handle._callback != callback:
raise AssertionError(
'unexpected callback: {handle._callback} != {callback}'.format(
handle=handle, callback=callback))
if handle._args != args:
raise AssertionError(
'unexpected callback args: {handle._args} != {args}'.format(
handle=handle, args=args))

def assert_no_reader(self, fd):
if fd in self.readers:
raise AssertionError('fd {fd} is registered'.format(fd=fd))

def _add_writer(self, fd, callback, *args):
self.writers[fd] = asyncio.events.Handle(callback, args, self)

def _remove_writer(self, fd):
self.remove_writer_count[fd] += 1
if fd in self.writers:
del self.writers[fd]
return True
else:
return False

def assert_writer(self, fd, callback, *args):
assert fd in self.writers, 'fd {} is not registered'.format(fd)
handle = self.writers[fd]
assert handle._callback == callback, '{!r} != {!r}'.format(
handle._callback, callback)
assert handle._args == args, '{!r} != {!r}'.format(
handle._args, args)

def _ensure_fd_no_transport(self, fd):
try:
transport = self._transports[fd]
except KeyError:
pass
else:
raise RuntimeError(
'File descriptor {!r} is used by transport {!r}'.format(
fd, transport))

def add_reader(self, fd, callback, *args):
"""Add a reader callback."""
self._ensure_fd_no_transport(fd)
return self._add_reader(fd, callback, *args)

def remove_reader(self, fd):
"""Remove a reader callback."""
self._ensure_fd_no_transport(fd)
return self._remove_reader(fd)

def add_writer(self, fd, callback, *args):
"""Add a writer callback.."""
self._ensure_fd_no_transport(fd)
return self._add_writer(fd, callback, *args)

def remove_writer(self, fd):
"""Remove a writer callback."""
self._ensure_fd_no_transport(fd)
return self._remove_writer(fd)

def reset_counters(self):
self.remove_reader_count = collections.defaultdict(int)
self.remove_writer_count = collections.defaultdict(int)

def _process_events(self, event_list):
return

def _write_to_self(self):
pass


class TransportTests(unittest.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = TestLoop()
self.sock = mock.Mock()
self.sock.closed = False
self.proto = test_utils.make_test_protocol(aiozmq.ZmqProtocol)
self.proto = make_test_protocol(aiozmq.ZmqProtocol)
self.tr = _ZmqTransportImpl(self.loop, zmq.SUB, self.sock, self.proto)
self.exc_handler = mock.Mock()
self.loop.set_exception_handler(self.exc_handler)
Expand Down Expand Up @@ -144,7 +286,7 @@ def test_close_with_empty_buffer(self):
self.assertIsNotNone(self.tr._loop)
self.assertFalse(self.sock.close.called)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.proto.connection_lost.assert_called_with(None)
self.assertIsNone(self.tr._protocol)
Expand All @@ -165,7 +307,7 @@ def test_close_already_closed_socket(self):
self.assertIsNotNone(self.tr._loop)
self.assertFalse(self.sock.close.called)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.proto.connection_lost.assert_called_with(None)
self.assertIsNone(self.tr._protocol)
Expand All @@ -191,7 +333,7 @@ def test_close_with_waiting_buffer(self):
self.assertIsNotNone(self.tr._loop)
self.assertFalse(self.sock.close.called)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.assertIsNotNone(self.tr._protocol)
self.assertIsNotNone(self.tr._zmq_sock)
Expand Down Expand Up @@ -234,7 +376,7 @@ def test_close_paused(self):
self.assertIsNotNone(self.tr._loop)
self.assertFalse(self.sock.close.called)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.proto.connection_lost.assert_called_with(None)
self.assertIsNone(self.tr._protocol)
Expand Down Expand Up @@ -275,7 +417,7 @@ def test_abort_with_empty_buffer(self):
self.assertIsNotNone(self.tr._loop)
self.assertFalse(self.sock.close.called)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.proto.connection_lost.assert_called_with(None)
self.assertIsNone(self.tr._protocol)
Expand All @@ -296,7 +438,7 @@ def test_abort_with_waiting_buffer(self):
self.assertEqual(0, self.tr._buffer_size)
self.assertTrue(self.tr._closing)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.assertIsNone(self.tr._protocol)
self.assertIsNone(self.tr._zmq_sock)
Expand All @@ -318,7 +460,7 @@ def test_abort_with_close_on_waiting_buffer(self):
self.assertEqual(0, self.tr._buffer_size)
self.assertTrue(self.tr._closing)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.assertIsNone(self.tr._protocol)
self.assertIsNone(self.tr._zmq_sock)
Expand All @@ -340,7 +482,7 @@ def test_abort_paused(self):
self.assertIsNotNone(self.tr._loop)
self.assertFalse(self.sock.close.called)

test_utils.run_briefly(self.loop)
self.loop.run_until_complete(dummy())

self.proto.connection_lost.assert_called_with(None)
self.assertIsNone(self.tr._protocol)
Expand Down Expand Up @@ -563,11 +705,11 @@ def test_conn_lost_on_force_close(self):
class LooplessTransportTests(unittest.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = TestLoop()
self.sock = mock.Mock()
self.sock.closed = False
self.waiter = asyncio.Future(loop=self.loop)
self.proto = test_utils.make_test_protocol(aiozmq.ZmqProtocol)
self.proto = make_test_protocol(aiozmq.ZmqProtocol)
self.tr = _ZmqLooplessTransportImpl(self.loop,
zmq.SUB, self.sock, self.proto,
self.waiter)
Expand Down
2 changes: 2 additions & 0 deletions tests/zmq_events_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def connect_rep():
return tr2, pr2

tr2, pr2 = self.loop.run_until_complete(connect_rep())
# Without this, this test hangs for some reason.
tr2._zmq_sock.getsockopt(zmq.EVENTS)

@asyncio.coroutine
def communicate():
Expand Down

0 comments on commit 4e6703c

Please sign in to comment.