diff --git a/Cargo.toml b/Cargo.toml index 44f9d3190f7..6d2c7948303 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -232,8 +232,11 @@ unsafe_op_in_unsafe_fn = "deny" elided_lifetimes_in_paths = "warn" [workspace.lints.clippy] +alloc_instead_of_core = "warn" +std_instead_of_alloc = "warn" +std_instead_of_core = "warn" perf = "warn" style = "warn" complexity = "warn" suspicious = "warn" -correctness = "warn" +correctness = "warn" \ No newline at end of file diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py new file mode 100644 index 00000000000..ab0b5aa9489 --- /dev/null +++ b/Lib/test/test_asyncio/__init__.py @@ -0,0 +1,12 @@ +import os +from test import support +from test.support import load_package_tests +from test.support import import_helper + +support.requires_working_socket(module=True) + +# Skip tests if we don't have concurrent.futures. +import_helper.import_module('concurrent.futures') + +def load_tests(*args): + return load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py new file mode 100644 index 00000000000..9a6ce77b534 --- /dev/null +++ b/Lib/test/test_asyncio/__main__.py @@ -0,0 +1,4 @@ +from test.test_asyncio import load_tests +import unittest + +unittest.main() diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py new file mode 100644 index 00000000000..006364bb007 --- /dev/null +++ b/Lib/test/test_asyncio/echo.py @@ -0,0 +1,8 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + os.write(1, buf) diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py new file mode 100644 index 00000000000..e83ca09fb7a --- /dev/null +++ b/Lib/test/test_asyncio/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/Lib/test/test_asyncio/echo3.py 
b/Lib/test/test_asyncio/echo3.py new file mode 100644 index 00000000000..064496736bf --- /dev/null +++ b/Lib/test/test_asyncio/echo3.py @@ -0,0 +1,11 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/Lib/test/test_asyncio/functional.py b/Lib/test/test_asyncio/functional.py new file mode 100644 index 00000000000..d19c7a612cc --- /dev/null +++ b/Lib/test/test_asyncio/functional.py @@ -0,0 +1,269 @@ +import asyncio +import asyncio.events +import contextlib +import os +import pprint +import select +import socket +import tempfile +import threading +from test import support + + +class FunctionalTestCaseMixin: + + def new_loop(self): + return asyncio.new_event_loop() + + def run_loop_briefly(self, *, delay=0.01): + self.loop.run_until_complete(asyncio.sleep(delay)) + + def loop_exception_handler(self, loop, context): + self.__unhandled_exceptions.append(context) + self.loop.default_exception_handler(context) + + def setUp(self): + self.loop = self.new_loop() + asyncio.set_event_loop(None) + + self.loop.set_exception_handler(self.loop_exception_handler) + self.__unhandled_exceptions = [] + + def tearDown(self): + try: + self.loop.close() + + if self.__unhandled_exceptions: + print('Unexpected calls to loop.call_exception_handler():') + pprint.pprint(self.__unhandled_exceptions) + self.fail('unexpected calls to loop.call_exception_handler()') + + finally: + asyncio.set_event_loop(None) + self.loop = None + + def tcp_server(self, server_prog, *, + family=socket.AF_INET, + addr=None, + timeout=support.LOOPBACK_TIMEOUT, + backlog=1, + max_clients=10): + + if addr is None: + if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX: + with tempfile.NamedTemporaryFile() as tmp: + addr = tmp.name + else: + addr = ('127.0.0.1', 0) + + sock = socket.create_server(addr, family=family, 
backlog=backlog) + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + return TestThreadedServer( + self, sock, server_prog, timeout, max_clients) + + def tcp_client(self, client_prog, + family=socket.AF_INET, + timeout=support.LOOPBACK_TIMEOUT): + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + return TestThreadedClient( + self, sock, client_prog, timeout) + + def unix_server(self, *args, **kwargs): + if not hasattr(socket, 'AF_UNIX'): + raise NotImplementedError + return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) + + def unix_client(self, *args, **kwargs): + if not hasattr(socket, 'AF_UNIX'): + raise NotImplementedError + return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) + + @contextlib.contextmanager + def unix_sock_name(self): + with tempfile.TemporaryDirectory() as td: + fn = os.path.join(td, 'sock') + try: + yield fn + finally: + try: + os.unlink(fn) + except OSError: + pass + + def _abort_socket_test(self, ex): + try: + self.loop.stop() + finally: + self.fail(ex) + + +############################################################################## +# Socket Testing Utilities +############################################################################## + + +class TestSocketWrapper: + + def __init__(self, sock): + self.__sock = sock + + def recv_all(self, n): + buf = b'' + while len(buf) < n: + data = self.recv(n - len(buf)) + if data == b'': + raise ConnectionAbortedError + buf += data + return buf + + def start_tls(self, ssl_context, *, + server_side=False, + server_hostname=None): + + ssl_sock = ssl_context.wrap_socket( + self.__sock, server_side=server_side, + server_hostname=server_hostname, + do_handshake_on_connect=False) + + try: 
+ ssl_sock.do_handshake() + except: + ssl_sock.close() + raise + finally: + self.__sock.close() + + self.__sock = ssl_sock + + def __getattr__(self, name): + return getattr(self.__sock, name) + + def __repr__(self): + return '<{} {!r}>'.format(type(self).__name__, self.__sock) + + +class SocketThread(threading.Thread): + + def stop(self): + self._active = False + self.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc): + self.stop() + + +class TestThreadedClient(SocketThread): + + def __init__(self, test, sock, prog, timeout): + threading.Thread.__init__(self, None, None, 'test-client') + self.daemon = True + + self._timeout = timeout + self._sock = sock + self._active = True + self._prog = prog + self._test = test + + def run(self): + try: + self._prog(TestSocketWrapper(self._sock)) + except Exception as ex: + self._test._abort_socket_test(ex) + + +class TestThreadedServer(SocketThread): + + def __init__(self, test, sock, prog, timeout, max_clients): + threading.Thread.__init__(self, None, None, 'test-server') + self.daemon = True + + self._clients = 0 + self._finished_clients = 0 + self._max_clients = max_clients + self._timeout = timeout + self._sock = sock + self._active = True + + self._prog = prog + + self._s1, self._s2 = socket.socketpair() + self._s1.setblocking(False) + + self._test = test + + def stop(self): + try: + if self._s2 and self._s2.fileno() != -1: + try: + self._s2.send(b'stop') + except OSError: + pass + finally: + super().stop() + + def run(self): + try: + with self._sock: + self._sock.setblocking(False) + self._run() + finally: + self._s1.close() + self._s2.close() + + def _run(self): + while self._active: + if self._clients >= self._max_clients: + return + + r, w, x = select.select( + [self._sock, self._s1], [], [], self._timeout) + + if self._s1 in r: + return + + if self._sock in r: + try: + conn, addr = self._sock.accept() + except BlockingIOError: + continue + except TimeoutError: + if not 
self._active: + return + else: + raise + else: + self._clients += 1 + conn.settimeout(self._timeout) + try: + with conn: + self._handle_client(conn) + except Exception as ex: + self._active = False + try: + raise + finally: + self._test._abort_socket_test(ex) + + def _handle_client(self, sock): + self._prog(TestSocketWrapper(sock)) + + @property + def addr(self): + return self._sock.getsockname() diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py new file mode 100644 index 00000000000..62a77380773 --- /dev/null +++ b/Lib/test/test_asyncio/test_base_events.py @@ -0,0 +1,2311 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import errno +import math +import platform +import socket +import sys +import threading +import time +import unittest +from unittest import mock + +import asyncio +from asyncio import base_events +from asyncio import constants +from test.test_asyncio import utils as test_utils +from test import support +from test.support.script_helper import assert_python_ok +from test.support import os_helper +from test.support import socket_helper +import warnings + +MOCK_ANY = mock.ANY + + +class CustomError(Exception): + pass + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +def mock_socket_module(): + m_socket = mock.MagicMock(spec=socket) + for name in ( + 'AF_INET', 'AF_INET6', 'AF_UNSPEC', 'IPPROTO_TCP', 'IPPROTO_UDP', + 'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton' + ): + if hasattr(socket, name): + setattr(m_socket, name, getattr(socket, name)) + else: + delattr(m_socket, name) + + m_socket.socket = mock.MagicMock() + m_socket.socket.return_value = test_utils.mock_nonblocking_socket() + + return m_socket + + +def patch_socket(f): + return mock.patch('asyncio.base_events.socket', + new_callable=mock_socket_module)(f) + + +class BaseEventTests(test_utils.TestCase): + + def test_ipaddr_info(self): + UNSPEC = socket.AF_UNSPEC + INET = socket.AF_INET + 
INET6 = socket.AF_INET6 + STREAM = socket.SOCK_STREAM + DGRAM = socket.SOCK_DGRAM + TCP = socket.IPPROTO_TCP + UDP = socket.IPPROTO_UDP + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP)) + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info(b'1.2.3.4', 1, INET, STREAM, TCP)) + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP)) + + self.assertEqual( + (INET, DGRAM, UDP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, UDP)) + + # Socket type STREAM implies TCP protocol. + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, 0)) + + # Socket type DGRAM implies UDP protocol. + self.assertEqual( + (INET, DGRAM, UDP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, 0)) + + # No socket type. + self.assertIsNone( + base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, 0, 0)) + + if socket_helper.IPV6_ENABLED: + # IPv4 address with family IPv6. + self.assertIsNone( + base_events._ipaddr_info('1.2.3.4', 1, INET6, STREAM, TCP)) + + self.assertEqual( + (INET6, STREAM, TCP, '', ('::3', 1, 0, 0)), + base_events._ipaddr_info('::3', 1, INET6, STREAM, TCP)) + + self.assertEqual( + (INET6, STREAM, TCP, '', ('::3', 1, 0, 0)), + base_events._ipaddr_info('::3', 1, UNSPEC, STREAM, TCP)) + + # IPv6 address with family IPv4. + self.assertIsNone( + base_events._ipaddr_info('::3', 1, INET, STREAM, TCP)) + + # IPv6 address with zone index. + self.assertIsNone( + base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP)) + + def test_port_parameter_types(self): + # Test obscure kinds of arguments for "port". 
+ INET = socket.AF_INET + STREAM = socket.SOCK_STREAM + TCP = socket.IPPROTO_TCP + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 0)), + base_events._ipaddr_info('1.2.3.4', None, INET, STREAM, TCP)) + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 0)), + base_events._ipaddr_info('1.2.3.4', b'', INET, STREAM, TCP)) + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 0)), + base_events._ipaddr_info('1.2.3.4', '', INET, STREAM, TCP)) + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info('1.2.3.4', '1', INET, STREAM, TCP)) + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info('1.2.3.4', b'1', INET, STREAM, TCP)) + + @patch_socket + def test_ipaddr_info_no_inet_pton(self, m_socket): + del m_socket.inet_pton + self.assertIsNone(base_events._ipaddr_info('1.2.3.4', 1, + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP)) + + +class BaseEventLoopTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = base_events.BaseEventLoop() + self.loop._selector = mock.Mock() + self.loop._selector.select.return_value = () + self.set_event_loop(self.loop) + + def test_not_implemented(self): + m = mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + with self.assertRaises(NotImplementedError): + gen.send(None) + + def test_close(self): + 
self.assertFalse(self.loop.is_closed()) + self.loop.close() + self.assertTrue(self.loop.is_closed()) + + # it should be possible to call close() more than once + self.loop.close() + self.loop.close() + + # operation blocked when the loop is closed + f = self.loop.create_future() + self.assertRaises(RuntimeError, self.loop.run_forever) + self.assertRaises(RuntimeError, self.loop.run_until_complete, f) + + def test__add_callback_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop, None) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_cancelled_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop, None) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + class DummyExecutor(concurrent.futures.ThreadPoolExecutor): + def submit(self, fn, *args, **kwargs): + raise NotImplementedError( + 'cannot submit into a dummy executor') + + self.loop._process_events = mock.Mock() + self.loop._write_to_self = mock.Mock() + + executor = DummyExecutor() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_set_default_executor_error(self): + executor = mock.Mock() + + msg = 'executor must be ThreadPoolExecutor instance' + with self.assertRaisesRegex(TypeError, msg): + self.loop.set_default_executor(executor) + + self.assertIsNone(self.loop._default_executor) + + def test_shutdown_default_executor_timeout(self): + event = threading.Event() + + class DummyExecutor(concurrent.futures.ThreadPoolExecutor): + def shutdown(self, wait=True, *, cancel_futures=False): + if wait: + event.wait() + + self.loop._process_events = mock.Mock() + self.loop._write_to_self = mock.Mock() + executor = DummyExecutor() + self.loop.set_default_executor(executor) + + try: + with self.assertWarnsRegex(RuntimeWarning, + "The executor did 
not finishing joining"): + self.loop.run_until_complete( + self.loop.shutdown_default_executor(timeout=0.01)) + finally: + event.set() + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, asyncio.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_soon_non_callable(self): + self.loop.set_debug(True) + with self.assertRaisesRegex(TypeError, 'a callable object'): + self.loop.call_soon(1) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, asyncio.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + with self.assertRaises(TypeError, msg="delay must not be None"): + self.loop.call_later(None, cb) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = mock.Mock() + delay = 0.100 + + when = self.loop.time() + delay + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + dt = self.loop.time() - t0 + + # 50 ms: maximum granularity of the event loop + self.assertGreaterEqual(dt, delay - test_utils.CLOCK_RES) + with self.assertRaises(TypeError, msg="when cannot be None"): + self.loop.call_at(None, cb) + + def check_thread(self, loop, debug): + def cb(): + pass + + loop.set_debug(debug) + if debug: + msg = ("Non-thread-safe operation invoked on an event loop other " + "than the current one") + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_soon(cb) + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_later(60, cb) + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_at(loop.time() + 60, 
cb) + else: + loop.call_soon(cb) + loop.call_later(60, cb) + loop.call_at(loop.time() + 60, cb) + + def test_check_thread(self): + def check_in_thread(loop, event, debug, create_loop, fut): + # wait until the event loop is running + event.wait() + + try: + if create_loop: + loop2 = base_events.BaseEventLoop() + try: + asyncio.set_event_loop(loop2) + self.check_thread(loop, debug) + finally: + asyncio.set_event_loop(None) + loop2.close() + else: + self.check_thread(loop, debug) + except Exception as exc: + loop.call_soon_threadsafe(fut.set_exception, exc) + else: + loop.call_soon_threadsafe(fut.set_result, None) + + def test_thread(loop, debug, create_loop=False): + event = threading.Event() + fut = loop.create_future() + loop.call_soon(event.set) + args = (loop, event, debug, create_loop, fut) + thread = threading.Thread(target=check_in_thread, args=args) + thread.start() + loop.run_until_complete(fut) + thread.join() + + self.loop._process_events = mock.Mock() + self.loop._write_to_self = mock.Mock() + + # raise RuntimeError if the thread has no event loop + test_thread(self.loop, True) + + # check disabled if debug mode is disabled + test_thread(self.loop, False) + + # raise RuntimeError if the event loop of the thread is not the called + # event loop + test_thread(self.loop, True, create_loop=True) + + # check disabled if debug mode is disabled + test_thread(self.loop, False, create_loop=True) + + def test__run_once(self): + h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (), + self.loop, None) + h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (), + self.loop, None) + + h1.cancel() + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.5 < t < 10.5, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + def test_set_debug(self): + 
self.loop.set_debug(True) + self.assertTrue(self.loop.get_debug()) + self.loop.set_debug(False) + self.assertFalse(self.loop.get_debug()) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), + self.loop, None) + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test__run_once_cancelled_event_cleanup(self): + self.loop._process_events = mock.Mock() + + self.assertTrue( + 0 < base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION < 1.0) + + def cb(): + pass + + # Set up one "blocking" event that will not be cancelled to + # ensure later cancelled events do not make it to the head + # of the queue and get cleaned. + not_cancelled_count = 1 + self.loop.call_later(3000, cb) + + # Add less than threshold (base_events._MIN_SCHEDULED_TIMER_HANDLES) + # cancelled handles, ensure they aren't removed + + cancelled_count = 2 + for x in range(2): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Add some cancelled events that will be at head and removed + cancelled_count += 2 + for x in range(2): + h = self.loop.call_later(100, cb) + h.cancel() + + # This test is invalid if _MIN_SCHEDULED_TIMER_HANDLES is too low + self.assertLessEqual(cancelled_count + not_cancelled_count, + base_events._MIN_SCHEDULED_TIMER_HANDLES) + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.loop._run_once() + + cancelled_count -= 2 + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + # Need enough events to pass _MIN_CANCELLED_TIMER_HANDLES_FRACTION + # so that deletion of cancelled events will occur on next _run_once + 
add_cancel_count = int(math.ceil( + base_events._MIN_SCHEDULED_TIMER_HANDLES * + base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION)) + 1 + + add_not_cancel_count = max(base_events._MIN_SCHEDULED_TIMER_HANDLES - + add_cancel_count, 0) + + # Add some events that will not be cancelled + not_cancelled_count += add_not_cancel_count + for x in range(add_not_cancel_count): + self.loop.call_later(3600, cb) + + # Add enough cancelled events + cancelled_count += add_cancel_count + for x in range(add_cancel_count): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Ensure all handles are still scheduled + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + self.loop._run_once() + + # Ensure cancelled events were removed + self.assertEqual(len(self.loop._scheduled), not_cancelled_count) + + # Ensure only uncancelled events remain scheduled + self.assertTrue(all([not x._cancelled for x in self.loop._scheduled])) + + def test_run_until_complete_type_error(self): + self.assertRaises(TypeError, + self.loop.run_until_complete, 'blah') + + def test_run_until_complete_loop(self): + task = self.loop.create_future() + other_loop = self.new_test_loop() + self.addCleanup(other_loop.close) + self.assertRaises(ValueError, + other_loop.run_until_complete, task) + + def test_run_until_complete_loop_orphan_future_close_loop(self): + class ShowStopper(SystemExit): + pass + + async def foo(delay): + await asyncio.sleep(delay) + + def throw(): + raise ShowStopper + + self.loop._process_events = mock.Mock() + self.loop.call_soon(throw) + with self.assertRaises(ShowStopper): + self.loop.run_until_complete(foo(0.1)) + + # This call fails if run_until_complete does not clean up + # done-callback for the previous future. 
+ self.loop.run_until_complete(foo(0.2)) + + def test_subprocess_exec_invalid_args(self): + args = [sys.executable, '-c', 'pass'] + + # missing program parameter (empty args) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol) + + # expected multiple arguments, not a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, args) + + # program arguments must be strings, not int + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, sys.executable, 123) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, bufsize=4096) + + def test_subprocess_shell_invalid_args(self): + # expected a string, not an int or a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 123) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, [sys.executable, '-c', 'pass']) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 
'exit 0', bufsize=4096) + + def test_default_exc_handler_callback(self): + self.loop._process_events = mock.Mock() + + def zero_error(fut): + fut.set_result(True) + 1/0 + + # Test call_soon (events.Handle) + with mock.patch('asyncio.base_events.logger') as log: + fut = self.loop.create_future() + self.loop.call_soon(zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + # Test call_later (events.TimerHandle) + with mock.patch('asyncio.base_events.logger') as log: + fut = self.loop.create_future() + self.loop.call_later(0.01, zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_coro(self): + self.loop._process_events = mock.Mock() + + async def zero_error_coro(): + await asyncio.sleep(0.01) + 1/0 + + # Test Future.__del__ + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.ensure_future(zero_error_coro(), loop=self.loop) + fut.add_done_callback(lambda *args: self.loop.stop()) + self.loop.run_forever() + fut = None # Trigger Future.__del__ or futures._TracebackLogger + support.gc_collect() + # Future.__del__ in logs error with an actual exception context + log.error.assert_called_with( + test_utils.MockPattern('.*exception was never retrieved'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_set_exc_handler_invalid(self): + with self.assertRaisesRegex(TypeError, 'A callable object or None'): + self.loop.set_exception_handler('spam') + + def test_set_exc_handler_custom(self): + def zero_error(): + 1/0 + + def run_loop(): + handle = self.loop.call_soon(zero_error) + self.loop._run_once() + return handle + + self.loop.set_debug(True) + 
self.loop._process_events = mock.Mock() + + self.assertIsNone(self.loop.get_exception_handler()) + mock_handler = mock.Mock() + self.loop.set_exception_handler(mock_handler) + self.assertIs(self.loop.get_exception_handler(), mock_handler) + handle = run_loop() + mock_handler.assert_called_with(self.loop, { + 'exception': MOCK_ANY, + 'message': test_utils.MockPattern( + 'Exception in callback.*zero_error'), + 'handle': handle, + 'source_traceback': handle._source_traceback, + }) + mock_handler.reset_mock() + + self.loop.set_exception_handler(None) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + self.assertFalse(mock_handler.called) + + def test_set_exc_handler_broken(self): + def run_loop(): + def zero_error(): + 1/0 + self.loop.call_soon(zero_error) + self.loop._run_once() + + def handler(loop, context): + raise AttributeError('spam') + + self.loop._process_events = mock.Mock() + + self.loop.set_exception_handler(handler) + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Unhandled error in exception handler'), + exc_info=(AttributeError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_broken(self): + _context = None + + class Loop(base_events.BaseEventLoop): + + _selector = mock.Mock() + _process_events = mock.Mock() + + def default_exception_handler(self, context): + nonlocal _context + _context = context + # Simulates custom buggy "default_exception_handler" + raise ValueError('spam') + + loop = Loop() + self.addCleanup(loop.close) + asyncio.set_event_loop(loop) + + def run_loop(): + def zero_error(): + 1/0 + loop.call_soon(zero_error) + loop._run_once() + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + 'Exception in default exception handler', + 
exc_info=True) + + def custom_handler(loop, context): + raise ValueError('ham') + + _context = None + loop.set_exception_handler(custom_handler) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern('Exception in default exception.*' + 'while handling.*in custom'), + exc_info=True) + + # Check that original context was passed to default + # exception handler. + self.assertIn('context', _context) + self.assertIs(type(_context['context']['exception']), + ZeroDivisionError) + + def test_set_task_factory_invalid(self): + with self.assertRaisesRegex( + TypeError, 'task factory must be a callable or None'): + + self.loop.set_task_factory(1) + + self.assertIsNone(self.loop.get_task_factory()) + + def test_set_task_factory(self): + self.loop._process_events = mock.Mock() + + class MyTask(asyncio.Task): + pass + + async def coro(): + pass + + factory = lambda loop, coro: MyTask(coro, loop=loop) + + self.assertIsNone(self.loop.get_task_factory()) + self.loop.set_task_factory(factory) + self.assertIs(self.loop.get_task_factory(), factory) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + + self.loop.set_task_factory(None) + self.assertIsNone(self.loop.get_task_factory()) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, asyncio.Task)) + self.assertFalse(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + + def test_env_var_debug(self): + code = '\n'.join(( + 'import asyncio', + 'loop = asyncio.new_event_loop()', + 'print(loop.get_debug())')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='', + PYTHONDEVMODE='') + self.assertEqual(stdout.rstrip(), 
b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONDEVMODE='') + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'False') + + # -X dev + sts, stdout, stderr = assert_python_ok('-E', '-X', 'dev', + '-c', code) + self.assertEqual(stdout.rstrip(), b'True') + + def test_create_task(self): + class MyTask(asyncio.Task): + pass + + async def test(): + pass + + class EventLoop(base_events.BaseEventLoop): + def create_task(self, coro): + return MyTask(coro, loop=loop) + + loop = EventLoop() + self.set_event_loop(loop) + + coro = test() + task = asyncio.ensure_future(coro, loop=loop) + self.assertIsInstance(task, MyTask) + + # make warnings quiet + task._log_destroy_pending = False + coro.close() + + def test_create_task_error_closes_coro(self): + async def test(): + pass + loop = asyncio.new_event_loop() + loop.close() + with warnings.catch_warnings(record=True) as w: + with self.assertRaises(RuntimeError): + asyncio.ensure_future(test(), loop=loop) + self.assertEqual(len(w), 0) + + + def test_create_named_task_with_default_factory(self): + async def test(): + pass + + loop = asyncio.new_event_loop() + task = loop.create_task(test(), name='test_task') + try: + self.assertEqual(task.get_name(), 'test_task') + finally: + loop.run_until_complete(task) + loop.close() + + def test_create_named_task_with_custom_factory(self): + def task_factory(loop, coro, **kwargs): + return asyncio.Task(coro, loop=loop, **kwargs) + + async def test(): + pass + + loop = asyncio.new_event_loop() + loop.set_task_factory(task_factory) + task = loop.create_task(test(), name='test_task') + try: + self.assertEqual(task.get_name(), 'test_task') + finally: + loop.run_until_complete(task) + loop.close() + + def test_run_forever_keyboard_interrupt(self): + # Python issue #22601: ensure that the temporary task created by + # run_forever() 
consumes the KeyboardInterrupt and so don't log + # a warning + async def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + self.loop.close() + support.gc_collect() + + self.assertFalse(self.loop.call_exception_handler.called) + + def test_run_until_complete_baseexception(self): + # Python issue #22429: run_until_complete() must not schedule a pending + # call to stop() if the future raised a BaseException + async def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + + with self.assertRaises(KeyboardInterrupt): + self.loop.run_until_complete(raise_keyboard_interrupt()) + + def func(): + self.loop.stop() + func.called = True + func.called = False + self.loop.call_soon(self.loop.call_soon, func) + self.loop.run_forever() + self.assertTrue(func.called) + + def test_single_selecter_event_callback_after_stopping(self): + # Python issue #25593: A stopped event loop may cause event callbacks + # to run more than once. + event_sentinel = object() + callcount = 0 + doer = None + + def proc_events(event_list): + nonlocal doer + if event_sentinel in event_list: + doer = self.loop.call_soon(do_event) + + def do_event(): + nonlocal callcount + callcount += 1 + self.loop.call_soon(clear_selector) + + def clear_selector(): + doer.cancel() + self.loop._selector.select.return_value = () + + self.loop._process_events = proc_events + self.loop._selector.select.return_value = (event_sentinel,) + + for i in range(1, 3): + with self.subTest('Loop %d/2' % i): + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(callcount, 1) + + def test_run_once(self): + # Simple test for test_utils.run_once(). It may seem strange + # to have a test for this (the function isn't even used!) 
but + # it's a de-factor standard API for library tests. This tests + # the idiom: loop.call_soon(loop.stop); loop.run_forever(). + count = 0 + + def callback(): + nonlocal count + count += 1 + + self.loop._process_events = mock.Mock() + self.loop.call_soon(callback) + test_utils.run_once(self.loop) + self.assertEqual(count, 1) + + def test_run_forever_pre_stopped(self): + # Test that the old idiom for pre-stopping the loop works. + self.loop._process_events = mock.Mock() + self.loop.stop() + self.loop.run_forever() + self.loop._selector.select.assert_called_once_with(0) + + + @unittest.skip('TODO: RUSTPYTHON') + # 'BaseEventLoop' object has no attribute '_run_forever_setup + def test_custom_run_forever_integration(self): + # Test that the run_forever_setup() and run_forever_cleanup() primitives + # can be used to implement a custom run_forever loop. + self.loop._process_events = mock.Mock() + + count = 0 + + def callback(): + nonlocal count + count += 1 + + self.loop.call_soon(callback) + + # Set up the custom event loop + self.loop._run_forever_setup() + + # Confirm the loop has been started + self.assertEqual(asyncio.get_running_loop(), self.loop) + self.assertTrue(self.loop.is_running()) + + # Our custom "event loop" just iterates 10 times before exiting. + for i in range(10): + self.loop._run_once() + + # Clean up the event loop + self.loop._run_forever_cleanup() + + # Confirm the loop has been cleaned up + with self.assertRaises(RuntimeError): + asyncio.get_running_loop() + self.assertFalse(self.loop.is_running()) + + # Confirm the loop actually did run, processing events 10 times, + # and invoking the callback once. + self.assertEqual(self.loop._process_events.call_count, 10) + self.assertEqual(count, 1) + + async def leave_unfinalized_asyncgen(self): + # Create an async generator, iterate it partially, and leave it + # to be garbage collected. + # Used in async generator finalization tests. + # Depends on implementation details of garbage collector. 
Changes + # in gc may break this function. + status = {'started': False, + 'stopped': False, + 'finalized': False} + + async def agen(): + status['started'] = True + try: + for item in ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR']: + yield item + finally: + status['finalized'] = True + + ag = agen() + ai = ag.__aiter__() + + async def iter_one(): + try: + item = await ai.__anext__() + except StopAsyncIteration: + return + if item == 'THREE': + status['stopped'] = True + return + asyncio.create_task(iter_one()) + + asyncio.create_task(iter_one()) + return status + + # TODO: RUSTPYTHON + # self.assertTrue(status['finalized']) + # AssertionError: False is not true + @unittest.expectedFailure + def test_asyncgen_finalization_by_gc(self): + # Async generators should be finalized when garbage collected. + self.loop._process_events = mock.Mock() + self.loop._write_to_self = mock.Mock() + with support.disable_gc(): + status = self.loop.run_until_complete(self.leave_unfinalized_asyncgen()) + while not status['stopped']: + test_utils.run_briefly(self.loop) + self.assertTrue(status['started']) + self.assertTrue(status['stopped']) + self.assertFalse(status['finalized']) + support.gc_collect() + test_utils.run_briefly(self.loop) + self.assertTrue(status['finalized']) + + # TODO: RUSTPYTHON + # self.assertTrue(status['finalized']) + # AssertionError: False is not true + @unittest.expectedFailure + def test_asyncgen_finalization_by_gc_in_other_thread(self): + # Python issue 34769: If garbage collector runs in another + # thread, async generators will not finalize in debug + # mode. 
+ self.loop._process_events = mock.Mock() + self.loop._write_to_self = mock.Mock() + self.loop.set_debug(True) + with support.disable_gc(): + status = self.loop.run_until_complete(self.leave_unfinalized_asyncgen()) + while not status['stopped']: + test_utils.run_briefly(self.loop) + self.assertTrue(status['started']) + self.assertTrue(status['stopped']) + self.assertFalse(status['finalized']) + self.loop.run_until_complete( + self.loop.run_in_executor(None, support.gc_collect)) + test_utils.run_briefly(self.loop) + self.assertTrue(status['finalized']) + + +class MyProto(asyncio.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = asyncio.get_running_loop().create_future() + + def _assert_state(self, *expected): + if self.state not in expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + self._assert_state('CONNECTED') + self.nbytes += len(data) + + def eof_received(self): + self._assert_state('CONNECTED') + self.state = 'EOF' + + def connection_lost(self, exc): + self._assert_state('CONNECTED', 'EOF') + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, create_future=False, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = loop.create_future() + + def _assert_state(self, expected): + if self.state != expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + 
self._assert_state('INITIALIZED') + self.nbytes += len(data) + + def error_received(self, exc): + self._assert_state('INITIALIZED') + + def connection_lost(self, exc): + self._assert_state('INITIALIZED') + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.SelectorEventLoop() + self.set_event_loop(self.loop) + + @mock.patch('socket.getnameinfo') + def test_getnameinfo(self, m_gai): + m_gai.side_effect = lambda *args: 42 + r = self.loop.run_until_complete(self.loop.getnameinfo(('abc', 123))) + self.assertEqual(r, 42) + + @patch_socket + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + async def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + idx = -1 + coro = self.loop.create_connection(MyProto, 'example.com', 80, all_errors=True) + with self.assertRaises(ExceptionGroup) as cm: + self.loop.run_until_complete(coro) + + self.assertIsInstance(cm.exception, ExceptionGroup) + for e in cm.exception.exceptions: + self.assertIsInstance(e, OSError) + + @patch_socket + def test_create_connection_timeout(self, m_socket): + # Ensure that the socket is closed on timeout + sock = mock.Mock() + m_socket.socket.return_value = sock + + def getaddrinfo(*args, **kw): + fut = self.loop.create_future() + 
addr = (socket.AF_INET, socket.SOCK_STREAM, 0, '', + ('127.0.0.1', 80)) + fut.set_result([addr]) + return fut + self.loop.getaddrinfo = getaddrinfo + + with mock.patch.object(self.loop, 'sock_connect', + side_effect=asyncio.TimeoutError): + coro = self.loop.create_connection(MyProto, '127.0.0.1', 80) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + + @patch_socket + @unittest.skip('TODO: RUSTPYTHON') + def test_create_connection_happy_eyeballs_empty_exceptions(self, m_socket): + # See gh-135836: Fix IndexError when Happy Eyeballs algorithm + # results in empty exceptions list + + async def getaddrinfo(*args, **kw): + return [(socket.AF_INET, socket.SOCK_STREAM, 0, '', ('127.0.0.1', 80)), + (socket.AF_INET6, socket.SOCK_STREAM, 0, '', ('::1', 80))] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + self.loop.getaddrinfo = getaddrinfo_task + + # Mock staggered_race to return empty exceptions list + # This simulates the scenario where Happy Eyeballs algorithm + # cancels all attempts but doesn't properly collect exceptions + with mock.patch('asyncio.staggered.staggered_race') as mock_staggered: + # Return (None, []) - no winner, empty exceptions list + async def mock_race(coro_fns, delay, loop): + return None, [] + mock_staggered.side_effect = mock_race + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, happy_eyeballs_delay=0.1) + + # Should raise TimeoutError instead of IndexError + with self.assertRaisesRegex(TimeoutError, "create_connection failed"): + self.loop.run_until_complete(coro) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_wrong_sock(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with sock: + coro = 
self.loop.create_connection(MyProto, sock=sock) + with self.assertRaisesRegex(ValueError, + 'A Stream Socket was expected'): + self.loop.run_until_complete(coro) + + def test_create_server_wrong_sock(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with sock: + coro = self.loop.create_server(MyProto, sock=sock) + with self.assertRaisesRegex(ValueError, + 'A Stream Socket was expected'): + self.loop.run_until_complete(coro) + + def test_create_server_ssl_timeout_for_plain_socket(self): + coro = self.loop.create_server( + MyProto, 'example.com', 80, ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), + 'no socket.SOCK_NONBLOCK (linux only)') + @unittest.skip('TODO: RUSTPYTHON') + def test_create_server_stream_bittype(self): + sock = socket.socket( + socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) + with sock: + coro = self.loop.create_server(lambda: None, sock=sock) + srv = self.loop.run_until_complete(coro) + srv.close() + self.loop.run_until_complete(srv.wait_closed()) + + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'no IPv6 support') + def test_create_server_ipv6(self): + async def main(): + srv = await asyncio.start_server(lambda: None, '::1', 0) + try: + self.assertGreater(len(srv.sockets), 0) + finally: + srv.close() + await srv.wait_closed() + + try: + self.loop.run_until_complete(main()) + except OSError as ex: + if (hasattr(errno, 'EADDRNOTAVAIL') and + ex.errno == errno.EADDRNOTAVAIL): + self.skipTest('failed to bind to ::1') + else: + raise + + def test_create_datagram_endpoint_wrong_sock(self): + sock = socket.socket(socket.AF_INET) + with sock: + coro = self.loop.create_datagram_endpoint(MyProto, sock=sock) + with self.assertRaisesRegex(ValueError, + 'A datagram socket was expected'): + self.loop.run_until_complete(coro) + + def 
test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + async def getaddrinfo(*args, **kw): + return [] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + async def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_connection(MyProto, 'example.com', 80, all_errors=True) + with self.assertRaises(ExceptionGroup) as cm: + self.loop.run_until_complete(coro) + + self.assertIsInstance(cm.exception, ExceptionGroup) + self.assertEqual(len(cm.exception.exceptions), 1) + self.assertIsInstance(cm.exception.exceptions[0], OSError) + + @patch_socket + @unittest.skip('TODO: RUSTPYTHON') + def test_create_connection_connect_non_os_err_close_err(self, m_socket): + # Test the case when sock_connect() raises non-OSError exception + # and sock.close() raises OSError. 
+ async def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = CustomError + sock = mock.Mock() + m_socket.socket.return_value = sock + sock.close.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + CustomError, self.loop.run_until_complete, coro) + + coro = self.loop.create_connection(MyProto, 'example.com', 80, all_errors=True) + self.assertRaises( + CustomError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + async def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, all_errors=True) + with self.assertRaises(ExceptionGroup) as cm: + self.loop.run_until_complete(coro) + + self.assertIsInstance(cm.exception, ExceptionGroup) + for e in cm.exception.exceptions: + self.assertIsInstance(e, OSError) + + @patch_socket + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + async def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return 
self.loop.create_task(getaddrinfo(*args, **kwds)) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080), all_errors=True) + with self.assertRaises(ExceptionGroup) as cm: + self.loop.run_until_complete(coro) + + self.assertIsInstance(cm.exception, ExceptionGroup) + for e in cm.exception.exceptions: + self.assertIsInstance(e, OSError) + + def _test_create_connection_ip_addr(self, m_socket, allow_inet_pton): + # Test the fallback code, even if this system has inet_pton. + if not allow_inet_pton: + del m_socket.inet_pton + + m_socket.getaddrinfo = socket.getaddrinfo + sock = m_socket.socket.return_value + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + + coro = self.loop.create_connection(asyncio.Protocol, '1.2.3.4', 80) + t, p = self.loop.run_until_complete(coro) + try: + sock.connect.assert_called_with(('1.2.3.4', 80)) + _, kwargs = m_socket.socket.call_args + self.assertEqual(kwargs['family'], m_socket.AF_INET) + self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM) + finally: + t.close() + test_utils.run_briefly(self.loop) # allow transport to close + + if socket_helper.IPV6_ENABLED: + sock.family = socket.AF_INET6 + coro = self.loop.create_connection(asyncio.Protocol, '::1', 80) + t, p = self.loop.run_until_complete(coro) + try: + # Without inet_pton we use getaddrinfo, which transforms + # ('::1', 80) to ('::1', 80, 0, 0). The last 0s are flow info, + # scope id. 
+ [address] = sock.connect.call_args[0] + host, port = address[:2] + self.assertRegex(host, r'::(0\.)*1') + self.assertEqual(port, 80) + _, kwargs = m_socket.socket.call_args + self.assertEqual(kwargs['family'], m_socket.AF_INET6) + self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM) + finally: + t.close() + test_utils.run_briefly(self.loop) # allow transport to close + + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'no IPv6 support') + @unittest.skipIf(sys.platform.startswith('aix'), + "bpo-25545: IPv6 scope id and getaddrinfo() behave differently on AIX") + @patch_socket + def test_create_connection_ipv6_scope(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + sock = m_socket.socket.return_value + sock.family = socket.AF_INET6 + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + + coro = self.loop.create_connection(asyncio.Protocol, 'fe80::1%1', 80) + t, p = self.loop.run_until_complete(coro) + try: + sock.connect.assert_called_with(('fe80::1', 80, 0, 1)) + _, kwargs = m_socket.socket.call_args + self.assertEqual(kwargs['family'], m_socket.AF_INET6) + self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM) + finally: + t.close() + test_utils.run_briefly(self.loop) # allow transport to close + + @patch_socket + def test_create_connection_ip_addr(self, m_socket): + self._test_create_connection_ip_addr(m_socket, True) + + @patch_socket + def test_create_connection_no_inet_pton(self, m_socket): + self._test_create_connection_ip_addr(m_socket, False) + + @patch_socket + @unittest.skipIf( + support.is_android and platform.android_ver().api_level < 23, + "Issue gh-71123: this fails on Android before API level 23" + ) + @unittest.skip('TODO: RUSTPYTHON') + def test_create_connection_service_name(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + sock = m_socket.socket.return_value + + self.loop._add_reader = mock.Mock() + self.loop._add_writer = mock.Mock() + + for service, port in ('http', 80), (b'http', 80): + 
coro = self.loop.create_connection(asyncio.Protocol, + '127.0.0.1', service) + + t, p = self.loop.run_until_complete(coro) + try: + sock.connect.assert_called_with(('127.0.0.1', port)) + _, kwargs = m_socket.socket.call_args + self.assertEqual(kwargs['family'], m_socket.AF_INET) + self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM) + finally: + t.close() + test_utils.run_briefly(self.loop) # allow transport to close + + for service in 'nonsense', b'nonsense': + coro = self.loop.create_connection(asyncio.Protocol, + '127.0.0.1', service) + + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + def test_create_connection_no_local_addr(self): + async def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @patch_socket + def test_create_connection_bluetooth(self, m_socket): + # See http://bugs.python.org/issue27136, fallback to getaddrinfo when + # we can't recognize an address is resolved, e.g. a Bluetooth address. 
+ addr = ('00:01:02:03:04:05', 1) + + def getaddrinfo(host, port, *args, **kw): + self.assertEqual((host, port), addr) + return [(999, 1, 999, '', (addr, 1))] + + m_socket.getaddrinfo = getaddrinfo + sock = m_socket.socket() + coro = self.loop.sock_connect(sock, addr) + self.loop.run_until_complete(coro) + + def test_create_connection_ssl_server_hostname_default(self): + self.loop.getaddrinfo = mock.Mock() + + def mock_getaddrinfo(*args, **kwds): + f = self.loop.create_future() + f.set_result([(socket.AF_INET, socket.SOCK_STREAM, + socket.SOL_TCP, '', ('1.2.3.4', 80))]) + return f + + self.loop.getaddrinfo.side_effect = mock_getaddrinfo + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.return_value = self.loop.create_future() + self.loop.sock_connect.return_value.set_result(None) + self.loop._make_ssl_transport = mock.Mock() + + class _SelectorTransportMock: + _sock = None + + def get_extra_info(self, key): + return mock.Mock() + + def close(self): + self._sock.close() + + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, + **kwds): + waiter.set_result(None) + transport = _SelectorTransportMock() + transport._sock = sock + return transport + + self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport + ANY = mock.ANY + handshake_timeout = object() + shutdown_timeout = object() + # First try the default server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection( + MyProto, 'python.org', 80, ssl=True, + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='python.org', + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) + # Next try an explicit server_hostname. 
+ self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection( + MyProto, 'python.org', 80, ssl=True, + server_hostname='perl.com', + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='perl.com', + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) + # Finally try an explicit empty server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection( + MyProto, 'python.org', 80, ssl=True, + server_hostname='', + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='', + ssl_handshake_timeout=handshake_timeout, + ssl_shutdown_timeout=shutdown_timeout) + + def test_create_connection_no_ssl_server_hostname_errors(self): + # When not using ssl, server_hostname must be None. + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='python.org') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_server_hostname_errors(self): + # When using ssl, server_hostname may be None if host is non-empty. 
+ coro = self.loop.create_connection(MyProto, '', 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, None, 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + sock = socket.socket() + coro = self.loop.create_connection(MyProto, None, None, + ssl=True, sock=sock) + self.addCleanup(sock.close) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_timeout_for_plain_socket(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + async def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + return [] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = mock.Mock() + getaddrinfo.return_value = self.loop.create_future() + getaddrinfo.return_value.set_result(None) + + f = self.loop.create_server(MyProto, 'python.org', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @patch_socket + def test_create_server_nosoreuseport(self, m_socket): + 
m_socket.getaddrinfo = socket.getaddrinfo + del m_socket.SO_REUSEPORT + m_socket.socket.return_value = mock.Mock() + + f = self.loop.create_server( + MyProto, '0.0.0.0', 0, reuse_port=True) + + self.assertRaises(ValueError, self.loop.run_until_complete, f) + + @patch_socket + def test_create_server_soreuseport_only_defined(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.return_value = mock.Mock() + m_socket.SO_REUSEPORT = -1 + + f = self.loop.create_server( + MyProto, '0.0.0.0', 0, reuse_port=True) + + self.assertRaises(ValueError, self.loop.run_until_complete, f) + + @patch_socket + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @patch_socket + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + TypeError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + TypeError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, 
self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_allow_broadcast(self): + protocol = MyDatagramProto(create_future=True, loop=self.loop) + self.loop.sock_connect = sock_connect = mock.Mock() + sock_connect.return_value = [] + + coro = self.loop.create_datagram_endpoint( + lambda: protocol, + remote_addr=('127.0.0.1', 0), + allow_broadcast=True) + + transport, _ = self.loop.run_until_complete(coro) + self.assertFalse(sock_connect.called) + + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + @patch_socket + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @patch_socket + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol) + self.assertRaises(ValueError, 
self.loop.run_until_complete, coro) + + @patch_socket + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_create_datagram_endpoint_sock(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(('127.0.0.1', 0)) + fut = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + sock=sock) + transport, protocol = self.loop.run_until_complete(fut) + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_datagram_endpoint_sock_unix(self): + fut = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + family=socket.AF_UNIX) + transport, protocol = self.loop.run_until_complete(fut) + self.assertEqual(transport._sock.family, socket.AF_UNIX) + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + @socket_helper.skip_unless_bind_unix_socket + def test_create_datagram_endpoint_existing_sock_unix(self): + with test_utils.unix_socket_path() as path: + sock = socket.socket(socket.AF_UNIX, type=socket.SOCK_DGRAM) + sock.bind(path) + sock.close() + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + path, family=socket.AF_UNIX) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.loop.run_until_complete(protocol.done) + + def test_create_datagram_endpoint_sock_sockopts(self): + class 
FakeSock: + type = socket.SOCK_DGRAM + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('127.0.0.1', 0), sock=FakeSock()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=FakeSock()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, family=1, sock=FakeSock()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, proto=1, sock=FakeSock()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, flags=1, sock=FakeSock()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, reuse_port=True, sock=FakeSock()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, allow_broadcast=True, sock=FakeSock()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + @unittest.skipIf(sys.platform == 'vxworks', + "SO_BROADCAST is enabled by default on VxWorks") + def test_create_datagram_endpoint_sockopts(self): + # Socket options should not be applied unless asked for. + # SO_REUSEPORT is not available on all platforms. 
+ + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + local_addr=('127.0.0.1', 0)) + transport, protocol = self.loop.run_until_complete(coro) + sock = transport.get_extra_info('socket') + + reuseport_supported = hasattr(socket, 'SO_REUSEPORT') + + if reuseport_supported: + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_BROADCAST)) + + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + local_addr=('127.0.0.1', 0), + reuse_port=reuseport_supported, + allow_broadcast=True) + transport, protocol = self.loop.run_until_complete(coro) + sock = transport.get_extra_info('socket') + + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR)) + if reuseport_supported: + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_BROADCAST)) + + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + @patch_socket + def test_create_datagram_endpoint_nosoreuseport(self, m_socket): + del m_socket.SO_REUSEPORT + m_socket.socket.return_value = mock.Mock() + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + local_addr=('127.0.0.1', 0), + reuse_port=True) + + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @patch_socket + def test_create_datagram_endpoint_ip_addr(self, m_socket): + def getaddrinfo(*args, **kw): + self.fail('should not have called getaddrinfo') + + m_socket.getaddrinfo = getaddrinfo + m_socket.socket.return_value.bind = bind = mock.Mock() + self.loop._add_reader = mock.Mock() + + reuseport_supported = hasattr(socket, 
'SO_REUSEPORT') + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + local_addr=('1.2.3.4', 0), + reuse_port=reuseport_supported) + + t, p = self.loop.run_until_complete(coro) + try: + bind.assert_called_with(('1.2.3.4', 0)) + m_socket.socket.assert_called_with(family=m_socket.AF_INET, + proto=m_socket.IPPROTO_UDP, + type=m_socket.SOCK_DGRAM) + finally: + t.close() + test_utils.run_briefly(self.loop) # allow transport to close + + def test_accept_connection_retry(self): + sock = mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @mock.patch('asyncio.base_events.logger') + def test_accept_connection_exception(self, m_log): + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') + self.loop._remove_reader = mock.Mock() + self.loop.call_later = mock.Mock() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(m_log.error.called) + self.assertFalse(sock.close.called) + self.loop._remove_reader.assert_called_with(10) + self.loop.call_later.assert_called_with( + constants.ACCEPT_RETRY_DELAY, + # self.loop._start_serving + mock.ANY, + MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY) + + def test_call_coroutine(self): + async def simple_coroutine(): + pass + + self.loop.set_debug(True) + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + with self.assertRaises(TypeError): + self.loop.call_soon(func) + with self.assertRaises(TypeError): + self.loop.call_soon_threadsafe(func) + with self.assertRaises(TypeError): + self.loop.call_later(60, func) + with self.assertRaises(TypeError): + self.loop.call_at(self.loop.time() + 60, func) + with self.assertRaises(TypeError): + self.loop.run_until_complete( + self.loop.run_in_executor(None, func)) + + 
@mock.patch('asyncio.base_events.logger') + def test_log_slow_callbacks(self, m_logger): + def stop_loop_cb(loop): + loop.stop() + + async def stop_loop_coro(loop): + loop.stop() + + asyncio.set_event_loop(self.loop) + self.loop.set_debug(True) + self.loop.slow_callback_duration = 0.0 + + # slow callback + self.loop.call_soon(stop_loop_cb, self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing " + "took .* seconds$") + + # slow task + asyncio.ensure_future(stop_loop_coro(self.loop), loop=self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing " + "took .* seconds$") + + +class RunningLoopTests(unittest.TestCase): + + def test_running_loop_within_a_loop(self): + async def runner(loop): + loop.run_forever() + + loop = asyncio.new_event_loop() + outer_loop = asyncio.new_event_loop() + try: + with self.assertRaisesRegex(RuntimeError, + 'while another loop is running'): + outer_loop.run_until_complete(runner(loop)) + finally: + loop.close() + outer_loop.close() + + +class BaseLoopSockSendfileTests(test_utils.TestCase): + + DATA = b"12345abcde" * 16 * 1024 # 160 KiB + + class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + + def connection_made(self, transport): + self.started = True + self.transport = transport + + def data_received(self, data): + self.data.extend(data) + + def connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + self.transport = None + + async def wait_closed(self): + await self.fut + + @classmethod + def setUpClass(cls): + cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE + constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 + with open(os_helper.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + 
@classmethod + def tearDownClass(cls): + constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize + os_helper.unlink(os_helper.TESTFN) + super().tearDownClass() + + def setUp(self): + from asyncio.selector_events import BaseSelectorEventLoop + # BaseSelectorEventLoop() has no native implementation + self.loop = BaseSelectorEventLoop() + self.set_event_loop(self.loop) + self.file = open(os_helper.TESTFN, 'rb') + self.addCleanup(self.file.close) + super().setUp() + + def make_socket(self, blocking=False): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(blocking) + self.addCleanup(sock.close) + return sock + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + def prepare(self): + sock = self.make_socket() + proto = self.MyProto(self.loop) + server = self.run_loop(self.loop.create_server( + lambda: proto, socket_helper.HOST, 0, family=socket.AF_INET)) + addr = server.sockets[0].getsockname() + + for _ in range(10): + try: + self.run_loop(self.loop.sock_connect(sock, addr)) + except OSError: + self.run_loop(asyncio.sleep(0.5)) + continue + else: + break + else: + # One last try, so we get the exception + self.run_loop(self.loop.sock_connect(sock, addr)) + + def cleanup(): + server.close() + sock.close() + if proto.transport is not None: + proto.transport.close() + self.run_loop(proto.wait_closed()) + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return sock, proto + + def test__sock_sendfile_native_failure(self): + sock, proto = self.prepare() + + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "sendfile is not available"): + self.run_loop(self.loop._sock_sendfile_native(sock, self.file, + 0, None)) + + self.assertEqual(proto.data, b'') + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_no_fallback(self): + sock, proto = self.prepare() + + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "sendfile is not available"): + 
self.run_loop(self.loop.sock_sendfile(sock, self.file, + fallback=False)) + + self.assertEqual(self.file.tell(), 0) + self.assertEqual(proto.data, b'') + + def test_sock_sendfile_fallback(self): + sock, proto = self.prepare() + + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(self.file.tell(), len(self.DATA)) + self.assertEqual(proto.data, self.DATA) + + def test_sock_sendfile_fallback_offset_and_count(self): + sock, proto = self.prepare() + + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, + 1000, 2000)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, 2000) + self.assertEqual(self.file.tell(), 3000) + self.assertEqual(proto.data, self.DATA[1000:3000]) + + def test_blocking_socket(self): + self.loop.set_debug(True) + sock = self.make_socket(blocking=True) + with self.assertRaisesRegex(ValueError, "must be non-blocking"): + self.run_loop(self.loop.sock_sendfile(sock, self.file)) + + def test_nonbinary_file(self): + sock = self.make_socket() + with open(os_helper.TESTFN, encoding="utf-8") as f: + with self.assertRaisesRegex(ValueError, "binary mode"): + self.run_loop(self.loop.sock_sendfile(sock, f)) + + def test_nonstream_socket(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + self.addCleanup(sock.close) + with self.assertRaisesRegex(ValueError, "only SOCK_STREAM type"): + self.run_loop(self.loop.sock_sendfile(sock, self.file)) + + def test_notint_count(self): + sock = self.make_socket() + with self.assertRaisesRegex(TypeError, + "count must be a positive integer"): + self.run_loop(self.loop.sock_sendfile(sock, self.file, 0, 'count')) + + def test_negative_count(self): + sock = self.make_socket() + with self.assertRaisesRegex(ValueError, + "count must be a positive integer"): + self.run_loop(self.loop.sock_sendfile(sock, self.file, 0, -1)) + + def 
test_notint_offset(self): + sock = self.make_socket() + with self.assertRaisesRegex(TypeError, + "offset must be a non-negative integer"): + self.run_loop(self.loop.sock_sendfile(sock, self.file, 'offset')) + + def test_negative_offset(self): + sock = self.make_socket() + with self.assertRaisesRegex(ValueError, + "offset must be a non-negative integer"): + self.run_loop(self.loop.sock_sendfile(sock, self.file, -1)) + + +class TestSelectorUtils(test_utils.TestCase): + def check_set_nodelay(self, sock): + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + self.assertFalse(opt) + + base_events._set_nodelay(sock) + + opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + self.assertTrue(opt) + + @unittest.skipUnless(hasattr(socket, 'TCP_NODELAY'), + 'need socket.TCP_NODELAY') + def test_set_nodelay(self): + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP) + with sock: + self.check_set_nodelay(sock) + + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP) + with sock: + sock.setblocking(False) + self.check_set_nodelay(sock) + + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_buffered_proto.py b/Lib/test/test_asyncio/test_buffered_proto.py new file mode 100644 index 00000000000..f24e363ebfc --- /dev/null +++ b/Lib/test/test_asyncio/test_buffered_proto.py @@ -0,0 +1,89 @@ +import asyncio +import unittest + +from test.test_asyncio import functional as func_tests + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class ReceiveStuffProto(asyncio.BufferedProtocol): + def __init__(self, cb, con_lost_fut): + self.cb = cb + self.con_lost_fut = con_lost_fut + + def get_buffer(self, sizehint): + self.buffer = bytearray(100) + return self.buffer + + def buffer_updated(self, nbytes): + self.cb(self.buffer[:nbytes]) + + def connection_lost(self, exc): + if exc is None: + self.con_lost_fut.set_result(None) + else: 
+ self.con_lost_fut.set_exception(exc) + + +class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin): + + def new_loop(self): + raise NotImplementedError + + def test_buffered_proto_create_connection(self): + + NOISE = b'12345678+' * 1024 + + async def client(addr): + data = b'' + + def on_buf(buf): + nonlocal data + data += buf + if data == NOISE: + tr.write(b'1') + + conn_lost_fut = self.loop.create_future() + + tr, pr = await self.loop.create_connection( + lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr) + + await conn_lost_fut + + async def on_server_client(reader, writer): + writer.write(NOISE) + await reader.readexactly(1) + writer.close() + await writer.wait_closed() + + srv = self.loop.run_until_complete( + asyncio.start_server( + on_server_client, '127.0.0.1', 0)) + + addr = srv.sockets[0].getsockname() + self.loop.run_until_complete( + asyncio.wait_for(client(addr), 5)) + + srv.close() + self.loop.run_until_complete(srv.wait_closed()) + + +class BufferedProtocolSelectorTests(BaseTestBufferedProtocol, + unittest.TestCase): + + def new_loop(self): + return asyncio.SelectorEventLoop() + + +@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') +class BufferedProtocolProactorTests(BaseTestBufferedProtocol, + unittest.TestCase): + + def new_loop(self): + return asyncio.ProactorEventLoop() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_context.py b/Lib/test/test_asyncio/test_context.py new file mode 100644 index 00000000000..f635bd0293a --- /dev/null +++ b/Lib/test/test_asyncio/test_context.py @@ -0,0 +1,41 @@ +import asyncio +import decimal +import unittest + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +@unittest.skipUnless(decimal.HAVE_CONTEXTVAR, "decimal is built with a thread-local context") +class DecimalContextTest(unittest.TestCase): + + # TODO: RUSTPYTHON + # AssertionError: '0.111111' != '0.111' + @unittest.expectedFailure + def 
test_asyncio_task_decimal_context(self): + async def fractions(t, precision, x, y): + with decimal.localcontext() as ctx: + ctx.prec = precision + a = decimal.Decimal(x) / decimal.Decimal(y) + await asyncio.sleep(t) + b = decimal.Decimal(x) / decimal.Decimal(y ** 2) + return a, b + + async def main(): + r1, r2 = await asyncio.gather( + fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3)) + + return r1, r2 + + r1, r2 = asyncio.run(main()) + + self.assertEqual(str(r1[0]), '0.333') + self.assertEqual(str(r1[1]), '0.111') + + self.assertEqual(str(r2[0]), '0.333333') + self.assertEqual(str(r2[1]), '0.111111') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py new file mode 100644 index 00000000000..687384012e4 --- /dev/null +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -0,0 +1,438 @@ +"""Tests for base_events.py""" + +import asyncio +import contextvars +import unittest + +from unittest import mock +from asyncio import tasks +from test.test_asyncio import utils as test_utils +from test.support.script_helper import assert_python_ok + +MOCK_ANY = mock.ANY + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class EagerTaskFactoryLoopTests: + + Task = None + + def run_coro(self, coro): + """ + Helper method to run the `coro` coroutine in the test event loop. + It helps with making sure the event loop is running before starting + to execute `coro`. This is important for testing the eager step + functionality, since an eager step is taken only if the event loop + is already running. 
+ """ + + async def coro_runner(): + self.assertTrue(asyncio.get_event_loop().is_running()) + return await coro + + return self.loop.run_until_complete(coro) + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.eager_task_factory = asyncio.create_eager_task_factory(self.Task) + self.loop.set_task_factory(self.eager_task_factory) + self.set_event_loop(self.loop) + + def test_eager_task_factory_set(self): + self.assertIsNotNone(self.eager_task_factory) + self.assertIs(self.loop.get_task_factory(), self.eager_task_factory) + + async def noop(): pass + + async def run(): + t = self.loop.create_task(noop()) + self.assertIsInstance(t, self.Task) + await t + + self.run_coro(run()) + + def test_await_future_during_eager_step(self): + + async def set_result(fut, val): + fut.set_result(val) + + async def run(): + fut = self.loop.create_future() + t = self.loop.create_task(set_result(fut, 'my message')) + # assert the eager step completed the task + self.assertTrue(t.done()) + return await fut + + self.assertEqual(self.run_coro(run()), 'my message') + + def test_eager_completion(self): + + async def coro(): + return 'hello' + + async def run(): + t = self.loop.create_task(coro()) + # assert the eager step completed the task + self.assertTrue(t.done()) + return await t + + self.assertEqual(self.run_coro(run()), 'hello') + + def test_block_after_eager_step(self): + + async def coro(): + await asyncio.sleep(0.1) + return 'finished after blocking' + + async def run(): + t = self.loop.create_task(coro()) + self.assertFalse(t.done()) + result = await t + self.assertTrue(t.done()) + return result + + self.assertEqual(self.run_coro(run()), 'finished after blocking') + + def test_cancellation_after_eager_completion(self): + + async def coro(): + return 'finished without blocking' + + async def run(): + t = self.loop.create_task(coro()) + t.cancel() + result = await t + # finished task can't be cancelled + self.assertFalse(t.cancelled()) + return result 
+ + self.assertEqual(self.run_coro(run()), 'finished without blocking') + + def test_cancellation_after_eager_step_blocks(self): + + async def coro(): + await asyncio.sleep(0.1) + return 'finished after blocking' + + async def run(): + t = self.loop.create_task(coro()) + t.cancel('cancellation message') + self.assertGreater(t.cancelling(), 0) + result = await t + + with self.assertRaises(asyncio.CancelledError) as cm: + self.run_coro(run()) + + self.assertEqual('cancellation message', cm.exception.args[0]) + + def test_current_task(self): + captured_current_task = None + + async def coro(): + nonlocal captured_current_task + captured_current_task = asyncio.current_task() + # verify the task before and after blocking is identical + await asyncio.sleep(0.1) + self.assertIs(asyncio.current_task(), captured_current_task) + + async def run(): + t = self.loop.create_task(coro()) + self.assertIs(captured_current_task, t) + await t + + self.run_coro(run()) + captured_current_task = None + + def test_all_tasks_with_eager_completion(self): + captured_all_tasks = None + + async def coro(): + nonlocal captured_all_tasks + captured_all_tasks = asyncio.all_tasks() + + async def run(): + t = self.loop.create_task(coro()) + self.assertIn(t, captured_all_tasks) + self.assertNotIn(t, asyncio.all_tasks()) + + self.run_coro(run()) + + def test_all_tasks_with_blocking(self): + captured_eager_all_tasks = None + + async def coro(fut1, fut2): + nonlocal captured_eager_all_tasks + captured_eager_all_tasks = asyncio.all_tasks() + await fut1 + fut2.set_result(None) + + async def run(): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + t = self.loop.create_task(coro(fut1, fut2)) + self.assertIn(t, captured_eager_all_tasks) + self.assertIn(t, asyncio.all_tasks()) + fut1.set_result(None) + await fut2 + self.assertNotIn(t, asyncio.all_tasks()) + + self.run_coro(run()) + + # TODO: RUSTPYTHON + # AssertionError: 2 != 1 + @unittest.expectedFailure + def 
test_context_vars(self): + cv = contextvars.ContextVar('cv', default=0) + + coro_first_step_ran = False + coro_second_step_ran = False + + async def coro(): + nonlocal coro_first_step_ran + nonlocal coro_second_step_ran + self.assertEqual(cv.get(), 1) + cv.set(2) + self.assertEqual(cv.get(), 2) + coro_first_step_ran = True + await asyncio.sleep(0.1) + self.assertEqual(cv.get(), 2) + cv.set(3) + self.assertEqual(cv.get(), 3) + coro_second_step_ran = True + + async def run(): + cv.set(1) + t = self.loop.create_task(coro()) + self.assertTrue(coro_first_step_ran) + self.assertFalse(coro_second_step_ran) + self.assertEqual(cv.get(), 1) + await t + self.assertTrue(coro_second_step_ran) + self.assertEqual(cv.get(), 1) + + self.run_coro(run()) + + # TODO: RUSTPYTHON + # assert len(exceptions) == this_index + 1 + @unittest.expectedFailure + def test_staggered_race_with_eager_tasks(self): + # See https://github.com/python/cpython/issues/124309 + + async def fail(): + await asyncio.sleep(0) + raise ValueError("no good") + + async def blocked(): + fut = asyncio.Future() + await fut + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: blocked(), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: fail() + ], + delay=0.25 + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIsInstance(excs[2], ValueError) + + self.run_coro(run()) + + @unittest.skip('TODO: RUSTPYTHON') + # Causes a hang + def test_staggered_race_with_eager_tasks_no_delay(self): + # See https://github.com/python/cpython/issues/124309 + async def fail(): + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: fail(), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: asyncio.sleep(0, result="sleep0"), + ], + delay=None + ) + self.assertEqual(winner, 'sleep1') + 
self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(len(excs), 2) + + self.run_coro(run()) + + +class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): + Task = tasks._PyTask + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + + def test_issue105987(self): + code = """if 1: + from _asyncio import _swap_current_task + + class DummyTask: + pass + + class DummyLoop: + pass + + l = DummyLoop() + _swap_current_task(l, DummyTask()) + t = _swap_current_task(l, None) + """ + + _, out, err = assert_python_ok("-c", code) + self.assertFalse(err) + + def test_issue122332(self): + async def coro(): + pass + + async def run(): + task = self.loop.create_task(coro()) + await task + self.assertIsNone(task.get_coro()) + + self.run_coro(run()) + + def test_name(self): + name = None + async def coro(): + nonlocal name + name = asyncio.current_task().get_name() + + async def main(): + task = self.loop.create_task(coro(), name="test name") + self.assertEqual(name, "test name") + await task + + self.run_coro(coro()) + +class AsyncTaskCounter: + def __init__(self, loop, *, task_class, eager): + self.suspense_count = 0 + self.task_count = 0 + + def CountingTask(*args, eager_start=False, **kwargs): + if not eager_start: + self.task_count += 1 + kwargs["eager_start"] = eager_start + return task_class(*args, **kwargs) + + if eager: + factory = asyncio.create_eager_task_factory(CountingTask) + else: + def factory(loop, coro, **kwargs): + return CountingTask(coro, loop=loop, **kwargs) + loop.set_task_factory(factory) + + def get(self): + return self.task_count + + +async def awaitable_chain(depth): + if depth == 0: + return 0 + return 1 + await awaitable_chain(depth - 1) + + +async def recursive_taskgroups(width, depth): + if depth == 0: + 
return + + async with asyncio.TaskGroup() as tg: + futures = [ + tg.create_task(recursive_taskgroups(width, depth - 1)) + for _ in range(width) + ] + + +async def recursive_gather(width, depth): + if depth == 0: + return + + await asyncio.gather( + *[recursive_gather(width, depth - 1) for _ in range(width)] + ) + + +class BaseTaskCountingTests: + + Task = None + eager = None + expected_task_count = None + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager) + self.set_event_loop(self.loop) + + def test_awaitables_chain(self): + observed_depth = self.loop.run_until_complete(awaitable_chain(100)) + self.assertEqual(observed_depth, 100) + self.assertEqual(self.counter.get(), 0 if self.eager else 1) + + def test_recursive_taskgroups(self): + num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4)) + self.assertEqual(self.counter.get(), self.expected_task_count) + + def test_recursive_gather(self): + self.loop.run_until_complete(recursive_gather(5, 4)) + self.assertEqual(self.counter.get(), self.expected_task_count) + + +class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests): + eager = False + expected_task_count = 781 # 1 + 5 + 5^2 + 5^3 + 5^4 + + +class BaseEagerTaskFactoryTests(BaseTaskCountingTests): + eager = True + expected_task_count = 0 + + +class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = asyncio.Task + + +class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = asyncio.Task + + +class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = tasks._PyTask + + +class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = tasks._PyTask + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + + 
+@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py new file mode 100644 index 00000000000..84dde0d9b34 --- /dev/null +++ b/Lib/test/test_asyncio/test_events.py @@ -0,0 +1,3067 @@ +"""Tests for events.py.""" + +import concurrent.futures +import contextlib +import functools +import io +import multiprocessing +import os +import platform +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import types +import errno +import unittest +from unittest import mock +import weakref +import warnings +if sys.platform not in ('win32', 'vxworks'): + import tty + +import asyncio +from asyncio import coroutines +from asyncio import events +from asyncio import selector_events +from multiprocessing.util import _cleanup_tests as multiprocessing_cleanup_tests +from test.test_asyncio import utils as test_utils +from test import support +from test.support import socket_helper +from test.support import threading_helper +from test.support import ALWAYS_EQ, LARGEST, SMALLEST + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +def broken_unix_getsockname(): + """Return True if the platform is Mac OS 10.4 or older.""" + if sys.platform.startswith("aix"): + return True + elif sys.platform != 'darwin': + return False + version = platform.mac_ver()[0] + version = tuple(map(int, version.split('.'))) + return version < (10, 5) + + +def _test_get_event_loop_new_process__sub_proc(): + async def doit(): + return 'hello' + + with contextlib.closing(asyncio.new_event_loop()) as loop: + asyncio.set_event_loop(loop) + return loop.run_until_complete(doit()) + + +class CoroLike: + def send(self, v): + pass + + 
def throw(self, *exc): + pass + + def close(self): + pass + + def __await__(self): + pass + + +class MyBaseProto(asyncio.Protocol): + connected = None + done = None + + def __init__(self, loop=None): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = loop.create_future() + self.done = loop.create_future() + + def _assert_state(self, *expected): + if self.state not in expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + + def data_received(self, data): + self._assert_state('CONNECTED') + self.nbytes += len(data) + + def eof_received(self): + self._assert_state('CONNECTED') + self.state = 'EOF' + + def connection_lost(self, exc): + self._assert_state('CONNECTED', 'EOF') + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = loop.create_future() + + def _assert_state(self, expected): + if self.state != expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + self._assert_state('INITIALIZED') + self.nbytes += len(data) + + def error_received(self, exc): + self._assert_state('INITIALIZED') + + def connection_lost(self, exc): + self._assert_state('INITIALIZED') + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + 
+class MyReadPipeProto(asyncio.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = loop.create_future() + + def _assert_state(self, expected): + if self.state != expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state(['INITIAL']) + self.state.append('CONNECTED') + + def data_received(self, data): + self._assert_state(['INITIAL', 'CONNECTED']) + self.nbytes += len(data) + + def eof_received(self): + self._assert_state(['INITIAL', 'CONNECTED']) + self.state.append('EOF') + + def connection_lost(self, exc): + if 'EOF' not in self.state: + self.state.append('EOF') # It is okay if EOF is missed. + self._assert_state(['INITIAL', 'CONNECTED', 'EOF']) + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(asyncio.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = loop.create_future() + + def _assert_state(self, expected): + if self.state != expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'CONNECTED' + + def connection_lost(self, exc): + self._assert_state('CONNECTED') + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(asyncio.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = loop.create_future() + self.completed = loop.create_future() + self.disconnects = {fd: loop.create_future() for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: asyncio.Event(), + 2: asyncio.Event()} + + def 
_assert_state(self, expected): + if self.state != expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + self._assert_state('CONNECTED') + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + self._assert_state('CONNECTED') + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + self._assert_state('CONNECTED') + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + self._assert_state('CONNECTED') + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + self.set_event_loop(self.loop) + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.doCleanups() + support.gc_collect() + super().tearDown() + + # TODO: RUSTPYTHON + # AssertionError: RuntimeWarning not triggered + @unittest.expectedFailure + def test_run_until_complete_nesting(self): + async def coro1(): + await asyncio.sleep(0) + + async def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + with self.assertWarnsRegex( + RuntimeWarning, + r"coroutine \S+ was never awaited" + ): + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). 
+ + def test_run_until_complete(self): + delay = 0.100 + t0 = self.loop.time() + self.loop.run_until_complete(asyncio.sleep(delay)) + dt = self.loop.time() - t0 + self.assertGreaterEqual(dt, delay - test_utils.CLOCK_RES) + + def test_run_until_complete_stopped(self): + + async def cb(): + self.loop.stop() + await asyncio.sleep(0.1) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + self.loop.run_forever() + self.assertEqual(results, ['hello world']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + 
def test_run_in_executor_cancel(self): + called = False + + def patched_call_soon(*args): + nonlocal called + called = True + + def run(): + time.sleep(0.05) + + f2 = self.loop.run_in_executor(None, run) + f2.cancel() + self.loop.run_until_complete( + self.loop.shutdown_default_executor()) + self.loop.close() + self.loop.call_soon = patched_call_soon + self.loop.call_soon_threadsafe = patched_call_soon + time.sleep(0.4) + self.assertFalse(called) + + def test_reader_callback(self): + r, w = socket.socketpair() + r.setblocking(False) + bytes_read = bytearray() + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.extend(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 3) + self.loop.call_soon(w.send, b'def') + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 6) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(bytes_read, b'abcdef') + + def test_writer_callback(self): + r, w = socket.socketpair() + w.setblocking(False) + + def writer(data): + w.send(data) + self.loop.stop() + + data = b'x' * 1024 + self.loop.add_writer(w.fileno(), writer, data) + self.loop.run_forever() + + self.assertTrue(self.loop.remove_writer(w.fileno())) + self.assertFalse(self.loop.remove_writer(w.fileno())) + + w.close() + read = r.recv(len(data) * 2) + r.close() + self.assertEqual(read, data) + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + @unittest.skip('TODO: RUSTPYTHON') + # OSError: Failed to set signal + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. 
+ self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_until(self.loop, lambda: caught) + + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + @unittest.skipUnless(hasattr(signal, 'setitimer'), + 'need signal.setitimer()') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. 
+ self.loop.call_later(60, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + @unittest.skipUnless(hasattr(signal, 'setitimer'), + 'need signal.setitimer()') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.loop.call_later(60, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def _basetest_create_connection(self, connection_fut, check_sockname=True): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertIs(pr.transport, tr) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + self._basetest_create_connection(conn_fut) + + @socket_helper.skip_unless_bind_unix_socket + def test_create_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. 
+ check_sockname = not broken_unix_getsockname() + + with test_utils.run_test_unix_server() as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), httpd.address) + self._basetest_create_connection(conn_fut, check_sockname) + + def check_ssl_extra_info(self, client, check_sockname=True, + peername=None, peercert={}): + if check_sockname: + self.assertIsNotNone(client.get_extra_info('sockname')) + if peername: + self.assertEqual(peername, + client.get_extra_info('peername')) + else: + self.assertIsNotNone(client.get_extra_info('peername')) + self.assertEqual(peercert, + client.get_extra_info('peercert')) + + # test SSL cipher + cipher = client.get_extra_info('cipher') + self.assertIsInstance(cipher, tuple) + self.assertEqual(len(cipher), 3, cipher) + self.assertIsInstance(cipher[0], str) + self.assertIsInstance(cipher[1], str) + self.assertIsInstance(cipher[2], int) + + # test SSL object + sslobj = client.get_extra_info('ssl_object') + self.assertIsNotNone(sslobj) + self.assertEqual(sslobj.compression(), + client.get_extra_info('compression')) + self.assertEqual(sslobj.cipher(), + client.get_extra_info('cipher')) + self.assertEqual(sslobj.getpeercert(), + client.get_extra_info('peercert')) + self.assertEqual(sslobj.compression(), + client.get_extra_info('compression')) + + def _basetest_create_ssl_connection(self, connection_fut, + check_sockname=True, + peername=None): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.check_ssl_extra_info(tr, check_sockname, peername) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def _test_create_ssl_connection(self, httpd, create_connection, + check_sockname=True, peername=None): + conn_fut = create_connection(ssl=test_utils.dummy_ssl_context()) + 
self._basetest_create_ssl_connection(conn_fut, check_sockname, + peername) + + # ssl.Purpose was introduced in Python 3.4 + if hasattr(ssl, 'Purpose'): + def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *, + cafile=None, capath=None, + cadata=None): + """ + A ssl.create_default_context() replacement that doesn't enable + cert validation. + """ + self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) + return test_utils.dummy_ssl_context() + + # With ssl=True, ssl.create_default_context() should be called + with mock.patch('ssl.create_default_context', + side_effect=_dummy_ssl_create_context) as m: + conn_fut = create_connection(ssl=True) + self._basetest_create_ssl_connection(conn_fut, check_sockname, + peername) + self.assertEqual(m.call_count, 1) + + # With the real ssl.create_default_context(), certificate + # validation will fail + with self.assertRaises(ssl.SSLError) as cm: + conn_fut = create_connection(ssl=True) + # Ignore the "SSL handshake failed" log in debug mode + with test_utils.disable_logger(): + self._basetest_create_ssl_connection(conn_fut, check_sockname, + peername) + + self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + create_connection = functools.partial( + self.loop.create_connection, + lambda: MyProto(loop=self.loop), + *httpd.address) + self._test_create_ssl_connection(httpd, create_connection, + peername=httpd.address) + + @socket_helper.skip_unless_bind_unix_socket + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. 
+ check_sockname = not broken_unix_getsockname() + + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + create_connection = functools.partial( + self.loop.create_unix_connection, + lambda: MyProto(loop=self.loop), httpd.address, + server_hostname='127.0.0.1') + + self._test_create_ssl_connection(httpd, create_connection, + check_sockname, + peername=httpd.address) + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = socket_helper.find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + @socket_helper.skip_if_tcp_blackhole + def test_create_connection_local_addr_skip_different_family(self): + # See https://github.com/python/cpython/issues/86508 + port1 = socket_helper.find_unused_port() + port2 = socket_helper.find_unused_port() + getaddrinfo_orig = self.loop.getaddrinfo + + async def getaddrinfo(host, port, *args, **kwargs): + if port == port2: + return [(socket.AF_INET6, socket.SOCK_STREAM, 0, '', ('::1', 0, 0, 0)), + (socket.AF_INET, socket.SOCK_STREAM, 0, '', ('127.0.0.1', 0))] + return await getaddrinfo_orig(host, port, *args, **kwargs) + + self.loop.getaddrinfo = getaddrinfo + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + 'localhost', port1, local_addr=('localhost', port2)) + + with self.assertRaises(OSError): + self.loop.run_until_complete(f) + + @socket_helper.skip_if_tcp_blackhole + def test_create_connection_local_addr_nomatch_family(self): + # See https://github.com/python/cpython/issues/86508 + port1 = socket_helper.find_unused_port() + port2 = socket_helper.find_unused_port() + getaddrinfo_orig = self.loop.getaddrinfo + + async def getaddrinfo(host, port, *args, **kwargs): + if port == port2: + return [(socket.AF_INET6, 
socket.SOCK_STREAM, 0, '', ('::1', 0, 0, 0))] + return await getaddrinfo_orig(host, port, *args, **kwargs) + + self.loop.getaddrinfo = getaddrinfo + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + 'localhost', port1, local_addr=('localhost', port2)) + + with self.assertRaises(OSError): + self.loop.run_until_complete(f) + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): + loop = self.loop + + class MyProto(MyBaseProto): + + def connection_lost(self, exc): + super().connection_lost(exc) + loop.call_soon(loop.stop) + + def data_received(self, data): + super().data_received(data) + self.transport.write(expected_response) + + lsock = socket.create_server(('127.0.0.1', 0), backlog=1) + addr = lsock.getsockname() + + message = b'test data' + response = None + expected_response = b'roger' + + def client(): + nonlocal response + try: + csock = socket.socket() + if client_ssl is not None: + csock = client_ssl.wrap_socket(csock) + csock.connect(addr) + csock.sendall(message) + response = csock.recv(99) + csock.close() + except Exception as exc: + print( + "Failure in client thread in test_connect_accepted_socket", + exc) + + thread = threading.Thread(target=client, daemon=True) + thread.start() + + conn, _ = lsock.accept() + proto = MyProto(loop=loop) + proto.loop = loop + loop.run_until_complete( + loop.connect_accepted_socket( + (lambda: proto), conn, ssl=server_ssl)) + loop.run_forever() + proto.transport.close() + lsock.close() + + threading_helper.join_thread(thread) + self.assertFalse(thread.is_alive()) + 
self.assertEqual(proto.state, 'CLOSED') + self.assertEqual(proto.nbytes, len(message)) + self.assertEqual(response, expected_response) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_ssl_connect_accepted_socket(self): + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + self.test_connect_accepted_socket(server_context, client_context) + + def test_connect_accepted_socket_ssl_timeout_for_plain_socket(self): + sock = socket.socket() + self.addCleanup(sock.close) + coro = self.loop.connect_accepted_socket( + MyProto, sock, ssl_handshake_timeout=support.LOOPBACK_TIMEOUT) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + + @mock.patch('asyncio.base_events.socket') + def create_server_multiple_hosts(self, family, hosts, mock_sock): + async def getaddrinfo(host, port, *args, **kw): + if family == socket.AF_INET: + return [(family, socket.SOCK_STREAM, 6, '', (host, port))] + else: + return [(family, socket.SOCK_STREAM, 6, '', (host, port, 0, 0))] + + def getaddrinfo_task(*args, **kwds): + return self.loop.create_task(getaddrinfo(*args, **kwds)) + + unique_hosts = set(hosts) + + if family == socket.AF_INET: + mock_sock.socket().getsockbyname.side_effect = [ + (host, 80) for host in unique_hosts] + else: + mock_sock.socket().getsockbyname.side_effect = [ + (host, 80, 0, 0) for host in unique_hosts] + self.loop.getaddrinfo = getaddrinfo_task + self.loop._start_serving = mock.Mock() + self.loop._stop_serving = mock.Mock() + f = self.loop.create_server(lambda: MyProto(self.loop), hosts, 80) + server = self.loop.run_until_complete(f) + self.addCleanup(server.close) + server_hosts = {sock.getsockbyname()[0] for sock in server.sockets} + self.assertEqual(server_hosts, unique_hosts) + + def test_create_server_multiple_hosts_ipv4(self): + self.create_server_multiple_hosts(socket.AF_INET, + ['1.2.3.4', '5.6.7.8', 
'1.2.3.4']) + + def test_create_server_multiple_hosts_ipv6(self): + self.create_server_multiple_hosts(socket.AF_INET6, + ['::1', '::2', '::1']) + + def test_create_server(self): + proto = MyProto(self.loop) + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.sendall(b'xxx') + + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def test_create_server_trsock(self): + proto = MyProto(self.loop) + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + self.assertIsInstance(sock, asyncio.trsock.TransportSocket) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + dup = sock.dup() + self.addCleanup(dup.close) + self.assertIsInstance(dup, socket.socket) + self.assertFalse(sock.get_inheritable()) + with self.assertRaises(ValueError): + sock.settimeout(1) + sock.settimeout(0) + self.assertEqual(sock.gettimeout(), 0) + with self.assertRaises(ValueError): + sock.setblocking(True) + sock.setblocking(False) + server.close() + + + @unittest.skipUnless(hasattr(socket, 
'SO_REUSEPORT'), 'No SO_REUSEPORT') + def test_create_server_reuse_port(self): + proto = MyProto(self.loop) + f = self.loop.create_server( + lambda: proto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + server.close() + + test_utils.run_briefly(self.loop) + + proto = MyProto(self.loop) + f = self.loop.create_server( + lambda: proto, '0.0.0.0', 0, reuse_port=True) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + server.close() + + def _make_unix_server(self, factory, **kwargs): + path = test_utils.gen_unix_socket_path() + self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) + + f = self.loop.create_unix_server(factory, path, **kwargs) + server = self.loop.run_until_complete(f) + + return server, path + + @socket_helper.skip_unless_bind_unix_socket + def test_create_unix_server(self): + proto = MyProto(loop=self.loop) + server, path = self._make_unix_server(lambda: proto) + self.assertEqual(len(server.sockets), 1) + + client = socket.socket(socket.AF_UNIX) + client.connect(path) + client.sendall(b'xxx') + + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_path_socket_error(self): + proto = MyProto(loop=self.loop) + sock 
= socket.socket() + with sock: + f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock) + with self.assertRaisesRegex(ValueError, + 'path and sock can not be specified ' + 'at the same time'): + self.loop.run_until_complete(f) + + def _create_ssl_context(self, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext + + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + + f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext) + server = self.loop.run_until_complete(f) + + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + return server, host, port + + def _make_ssl_unix_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + return self._make_unix_server(factory, ssl=sslcontext) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY) + + f_c = self.loop.create_connection(MyBaseProto, host, port, + ssl=test_utils.dummy_ssl_context()) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.check_ssl_extra_info(client, peername=(host, port)) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + 
    @socket_helper.skip_unless_bind_unix_socket
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_unix_server_ssl(self):
        """A TLS unix-domain server completes a handshake and receives bytes."""
        proto = MyProto(loop=self.loop)
        server, path = self._make_ssl_unix_server(
            lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)

        f_c = self.loop.create_unix_connection(
            MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
            server_hostname='')

        client, pr = self.loop.run_until_complete(f_c)

        client.write(b'xxx')
        self.loop.run_until_complete(proto.connected)
        self.assertEqual('CONNECTED', proto.state)
        test_utils.run_until(self.loop, lambda: proto.nbytes > 0)
        self.assertEqual(3, proto.nbytes)

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)

        # the client socket must be closed after to avoid ECONNRESET upon
        # recv()/send() on the serving socket
        client.close()

        # stop serving
        server.close()

    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_server_ssl_verify_failed(self):
        """With no CA loaded on the client, the handshake fails verification."""
        proto = MyProto(loop=self.loop)
        server, host, port = self._make_ssl_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # no CA loaded
        f_c = self.loop.create_connection(MyProto, host, port,
                                          ssl=sslcontext_client)
        with mock.patch.object(self.loop, 'call_exception_handler'):
            with test_utils.disable_logger():
                with self.assertRaisesRegex(ssl.SSLError,
                                            '(?i)certificate.verify.failed'):
                    self.loop.run_until_complete(f_c)

            # execute the loop to log the connection error
            test_utils.run_briefly(self.loop)

        # close connection
        self.assertIsNone(proto.transport)
        server.close()

    @socket_helper.skip_unless_bind_unix_socket
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_unix_server_ssl_verify_failed(self):
        """Same as above over a unix-domain socket: no CA means verify failure."""
        proto = MyProto(loop=self.loop)
        server, path = self._make_ssl_unix_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # no CA loaded
        f_c = self.loop.create_unix_connection(MyProto, path,
                                               ssl=sslcontext_client,
                                               server_hostname='invalid')
        with mock.patch.object(self.loop, 'call_exception_handler'):
            with test_utils.disable_logger():
                with self.assertRaisesRegex(ssl.SSLError,
                                            '(?i)certificate.verify.failed'):
                    self.loop.run_until_complete(f_c)

            # execute the loop to log the connection error
            test_utils.run_briefly(self.loop)

        # close connection
        self.assertIsNone(proto.transport)
        server.close()

    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_server_ssl_match_failed(self):
        """With a trusted CA but wrong hostname, the handshake fails matching."""
        proto = MyProto(loop=self.loop)
        server, host, port = self._make_ssl_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        sslcontext_client.load_verify_locations(
            cafile=test_utils.SIGNING_CA)
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # incorrect server_hostname
        f_c = self.loop.create_connection(MyProto, host, port,
                                          ssl=sslcontext_client)

        # Allow for flexible libssl error messages.
        regex = re.compile(r"""(
            IP address mismatch, certificate is not valid for '127.0.0.1'   # OpenSSL
            |
            CERTIFICATE_VERIFY_FAILED   # AWS-LC
        )""", re.X)
        with mock.patch.object(self.loop, 'call_exception_handler'):
            with test_utils.disable_logger():
                with self.assertRaisesRegex(ssl.CertificateError, regex):
                    self.loop.run_until_complete(f_c)

        # close connection
        # transport is None because TLS ALERT aborted the handshake
        self.assertIsNone(proto.transport)
        server.close()

    @socket_helper.skip_unless_bind_unix_socket
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_create_unix_server_ssl_verified(self):
        """With the correct CA and hostname, the unix TLS connection succeeds."""
        proto = MyProto(loop=self.loop)
        server, path = self._make_ssl_unix_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # Connection succeeds with correct CA and server hostname.
        f_c = self.loop.create_unix_connection(MyProto, path,
                                               ssl=sslcontext_client,
                                               server_hostname='localhost')
        client, pr = self.loop.run_until_complete(f_c)
        self.loop.run_until_complete(proto.connected)

        # close connection
        proto.transport.close()
        client.close()
        server.close()
        self.loop.run_until_complete(proto.done)

    @unittest.skipIf(ssl is None, 'No ssl module')
    # TODO: RUSTPYTHON
    # AssertionError: {'OCSP': ('http://testca.pythontest.net/tes[629 chars]': 3} != {'subject': ((('countryName', ''),), (('loc[419 chars]'),)}
    @unittest.expectedFailure
    def test_create_server_ssl_verified(self):
        """With the correct CA and hostname, the TCP TLS connection succeeds."""
        proto = MyProto(loop=self.loop)
        server, host, port = self._make_ssl_server(
            lambda: proto, test_utils.SIGNED_CERTFILE)

        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        sslcontext_client.options |= ssl.OP_NO_SSLv2
        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
        if hasattr(sslcontext_client, 'check_hostname'):
            sslcontext_client.check_hostname = True

        # Connection succeeds with correct CA and server hostname.
        f_c = self.loop.create_connection(MyProto, host, port,
                                          ssl=sslcontext_client,
                                          server_hostname='localhost')
        client, pr = self.loop.run_until_complete(f_c)
        self.loop.run_until_complete(proto.connected)

        # extra info is available
        self.check_ssl_extra_info(client, peername=(host, port),
                                  peercert=test_utils.PEERCERT)

        # close connection
        proto.transport.close()
        client.close()
        server.close()
        self.loop.run_until_complete(proto.done)

    def test_create_server_sock(self):
        """create_server(sock=...) serves on the provided pre-bound socket."""
        proto = self.loop.create_future()

        class TestMyProto(MyProto):
            def connection_made(self, transport):
                super().connection_made(transport)
                proto.set_result(self)

        sock_ob = socket.create_server(('0.0.0.0', 0))

        f = self.loop.create_server(TestMyProto, sock=sock_ob)
        server = self.loop.run_until_complete(f)
        sock = server.sockets[0]
        # same underlying FD: the server adopted our socket
        self.assertEqual(sock.fileno(), sock_ob.fileno())

        host, port = sock.getsockname()
        self.assertEqual(host, '0.0.0.0')
        client = socket.socket()
        client.connect(('127.0.0.1', port))
        client.send(b'xxx')
        client.close()
        server.close()

    def test_create_server_addr_in_use(self):
        """Binding a second server to an occupied address raises EADDRINUSE."""
        sock_ob = socket.create_server(('0.0.0.0', 0))

        f = self.loop.create_server(MyProto, sock=sock_ob)
        server = self.loop.run_until_complete(f)
        sock = server.sockets[0]
        host, port = sock.getsockname()

        f = self.loop.create_server(MyProto, host=host, port=port)
        with self.assertRaises(OSError) as cm:
            self.loop.run_until_complete(f)
        self.assertEqual(cm.exception.errno, errno.EADDRINUSE)

        server.close()

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 not supported or enabled')
    def test_create_server_dual_stack(self):
        """host=None binds both IPv4 and IPv6; clients connect over each."""
        f_proto = self.loop.create_future()

        class TestMyProto(MyProto):
            def connection_made(self, transport):
                super().connection_made(transport)
                f_proto.set_result(self)

        # find_unused_port() is racy: retry (up to 5 times) on EADDRINUSE
        try_count = 0
        while True:
            try:
                port = socket_helper.find_unused_port()
                f = self.loop.create_server(TestMyProto, host=None,
                                            port=port)
                server = self.loop.run_until_complete(f)
            except OSError as ex:
                if ex.errno == errno.EADDRINUSE:
                    try_count += 1
                    self.assertGreaterEqual(5, try_count)
                    continue
                else:
                    raise
            else:
                break
        client = socket.socket()
        client.connect(('127.0.0.1', port))
        client.send(b'xxx')
        proto = self.loop.run_until_complete(f_proto)
        proto.transport.close()
        client.close()

        f_proto = self.loop.create_future()
        client = socket.socket(socket.AF_INET6)
        client.connect(('::1', port))
        client.send(b'xxx')
        proto = self.loop.run_until_complete(f_proto)
        proto.transport.close()
        client.close()

        server.close()

    @socket_helper.skip_if_tcp_blackhole
    def test_server_close(self):
        """After server.close(), new connections are refused."""
        f = self.loop.create_server(MyProto, '0.0.0.0', 0)
        server = self.loop.run_until_complete(f)
        sock = server.sockets[0]
        host, port = sock.getsockname()

        client = socket.socket()
        client.connect(('127.0.0.1', port))
        client.send(b'xxx')
        client.close()

        server.close()

        client = socket.socket()
        self.assertRaises(
            ConnectionRefusedError, client.connect, ('127.0.0.1', port))
        client.close()

    def _test_create_datagram_endpoint(self, local_addr, family):
        """Round-trip a datagram: server echoes b'resp:'+data back to the client."""
        class TestMyDatagramProto(MyDatagramProto):
            def __init__(inner_self):
                # note: named `inner_self` so the closure can reach the
                # TestCase's `self.loop`
                super().__init__(loop=self.loop)

            def datagram_received(self, data, addr):
                super().datagram_received(data, addr)
                self.transport.sendto(b'resp:'+data, addr)

        coro = self.loop.create_datagram_endpoint(
            TestMyDatagramProto, local_addr=local_addr, family=family)
        s_transport, server = self.loop.run_until_complete(coro)
        sockname = s_transport.get_extra_info('sockname')
        host, port = socket.getnameinfo(
            sockname, socket.NI_NUMERICHOST|socket.NI_NUMERICSERV)

        self.assertIsInstance(s_transport, asyncio.Transport)
        self.assertIsInstance(server, TestMyDatagramProto)
        self.assertEqual('INITIALIZED', server.state)
        self.assertIs(server.transport, s_transport)

        coro = self.loop.create_datagram_endpoint(
            lambda: MyDatagramProto(loop=self.loop),
            remote_addr=(host, port))
        transport, client = self.loop.run_until_complete(coro)

        self.assertIsInstance(transport, asyncio.Transport)
        self.assertIsInstance(client, MyDatagramProto)
        self.assertEqual('INITIALIZED', client.state)
        self.assertIs(client.transport, transport)

        transport.sendto(b'xxx')
        test_utils.run_until(self.loop, lambda: server.nbytes)
        self.assertEqual(3, server.nbytes)
        test_utils.run_until(self.loop, lambda: client.nbytes)

        # received: len(b'resp:') + len(b'xxx') == 8
        self.assertEqual(8, client.nbytes)

        # extra info is available
        self.assertIsNotNone(transport.get_extra_info('sockname'))

        # close connection
        transport.close()
        self.loop.run_until_complete(client.done)
        self.assertEqual('CLOSED', client.state)
        server.transport.close()

    def test_create_datagram_endpoint(self):
        self._test_create_datagram_endpoint(('127.0.0.1', 0), socket.AF_INET)

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 not supported or enabled')
    def test_create_datagram_endpoint_ipv6(self):
        self._test_create_datagram_endpoint(('::1', 0), socket.AF_INET6)

    def test_create_datagram_endpoint_sock(self):
        """create_datagram_endpoint(sock=...) works with a pre-bound UDP socket."""
        sock = None
        local_address = ('127.0.0.1', 0)
        infos = self.loop.run_until_complete(
            self.loop.getaddrinfo(
                *local_address, type=socket.SOCK_DGRAM))
        for family, type, proto, cname, address in infos:
            try:
                sock = socket.socket(family=family, type=type, proto=proto)
                sock.setblocking(False)
                sock.bind(address)
            except:
                # NOTE(review): bare except deliberately skips unusable
                # addrinfo entries; the for/else fails the test if none work
                pass
            else:
                break
        else:
            self.fail('Can not create socket.')

        f = self.loop.create_datagram_endpoint(
            lambda: MyDatagramProto(loop=self.loop), sock=sock)
        tr, pr = self.loop.run_until_complete(f)
        self.assertIsInstance(tr, asyncio.Transport)
        self.assertIsInstance(pr, MyDatagramProto)
        tr.close()
        self.loop.run_until_complete(pr.done)

    def test_datagram_send_to_non_listening_address(self):
        # see:
        # https://github.com/python/cpython/issues/91227
        # https://github.com/python/cpython/issues/88906
        # https://bugs.python.org/issue47071
        # https://bugs.python.org/issue44743
        # The Proactor event loop would fail to receive datagram messages after
        # sending a message to an address that wasn't listening.
        loop = self.loop

        class Protocol(asyncio.DatagramProtocol):

            _received_datagram = None

            def datagram_received(self, data, addr):
                self._received_datagram.set_result(data)

            async def wait_for_datagram_received(self):
                self._received_datagram = loop.create_future()
                result = await asyncio.wait_for(self._received_datagram, 10)
                self._received_datagram = None
                return result

        def create_socket():
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.setblocking(False)
            sock.bind(('127.0.0.1', 0))
            return sock

        socket_1 = create_socket()
        transport_1, protocol_1 = loop.run_until_complete(
            loop.create_datagram_endpoint(Protocol, sock=socket_1)
        )
        addr_1 = socket_1.getsockname()

        socket_2 = create_socket()
        transport_2, protocol_2 = loop.run_until_complete(
            loop.create_datagram_endpoint(Protocol, sock=socket_2)
        )
        addr_2 = socket_2.getsockname()

        # creating and immediately closing this to try to get an address that
        # is not listening
        socket_3 = create_socket()
        transport_3, protocol_3 = loop.run_until_complete(
            loop.create_datagram_endpoint(Protocol, sock=socket_3)
        )
        addr_3 = socket_3.getsockname()
        transport_3.abort()

        transport_1.sendto(b'a', addr=addr_2)
        self.assertEqual(loop.run_until_complete(
            protocol_2.wait_for_datagram_received()
        ), b'a')

        transport_2.sendto(b'b', addr=addr_1)
        self.assertEqual(loop.run_until_complete(
            protocol_1.wait_for_datagram_received()
        ), b'b')

        # this should send to an address that isn't listening
        transport_1.sendto(b'c', addr=addr_3)
        loop.run_until_complete(asyncio.sleep(0))

        # transport 1 should still be able to receive messages after sending to
        # an address that wasn't listening
        transport_2.sendto(b'd', addr=addr_1)
        self.assertEqual(loop.run_until_complete(
            protocol_1.wait_for_datagram_received()
        ), b'd')

        transport_1.close()
        transport_2.close()

    def test_internal_fds(self):
        """A selector loop holds one internal FD pair, released on close()."""
        loop = self.create_event_loop()
        if not isinstance(loop, selector_events.BaseSelectorEventLoop):
            loop.close()
            self.skipTest('loop is not a BaseSelectorEventLoop')

        self.assertEqual(1, loop._internal_fds)
        loop.close()
        self.assertEqual(0, loop._internal_fds)
        self.assertIsNone(loop._csock)
        self.assertIsNone(loop._ssock)

    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_read_pipe(self):
        """connect_read_pipe delivers pipe bytes and EOF to the protocol."""
        proto = MyReadPipeProto(loop=self.loop)

        rpipe, wpipe = os.pipe()
        pipeobj = io.open(rpipe, 'rb', 1024)

        async def connect():
            t, p = await self.loop.connect_read_pipe(
                lambda: proto, pipeobj)
            self.assertIs(p, proto)
            self.assertIs(t, proto.transport)
            self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
            self.assertEqual(0, proto.nbytes)

        self.loop.run_until_complete(connect())

        os.write(wpipe, b'1')
        test_utils.run_until(self.loop, lambda: proto.nbytes >= 1)
        self.assertEqual(1, proto.nbytes)

        os.write(wpipe, b'2345')
        test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
        self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
        self.assertEqual(5, proto.nbytes)

        # closing the write end produces EOF, then the transport closes
        os.close(wpipe)
        self.loop.run_until_complete(proto.done)
        self.assertEqual(
            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))

    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_unclosed_pipe_transport(self):
        # This test reproduces the issue #314 on GitHub
        loop = self.create_event_loop()
        read_proto = MyReadPipeProto(loop=loop)
        write_proto = MyWritePipeProto(loop=loop)

        rpipe, wpipe = os.pipe()
        rpipeobj = io.open(rpipe, 'rb', 1024)
        wpipeobj = io.open(wpipe, 'w', 1024, encoding="utf-8")

        async def connect():
            read_transport, _ = await loop.connect_read_pipe(
                lambda: read_proto, rpipeobj)
            write_transport, _ = await loop.connect_write_pipe(
                lambda: write_proto, wpipeobj)
            return read_transport, write_transport

        # Run and close the loop without closing the transports
        read_transport, write_transport = loop.run_until_complete(connect())
        loop.close()

        # These 'repr' calls used to raise an AttributeError
        # See Issue #314 on GitHub
        self.assertIn('open', repr(read_transport))
        self.assertIn('open', repr(write_transport))

        # Clean up (avoid ResourceWarning)
        rpipeobj.close()
        wpipeobj.close()
        read_transport._pipe = None
        write_transport._pipe = None

    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
    def test_read_pty_output(self):
        """Reading the master side of a PTY delivers slave writes to the protocol."""
        proto = MyReadPipeProto(loop=self.loop)

        master, slave = os.openpty()
        master_read_obj = io.open(master, 'rb', 0)

        async def connect():
            t, p = await self.loop.connect_read_pipe(lambda: proto,
                                                     master_read_obj)
            self.assertIs(p, proto)
            self.assertIs(t, proto.transport)
            self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
            self.assertEqual(0, proto.nbytes)

        self.loop.run_until_complete(connect())

        os.write(slave, b'1')
        test_utils.run_until(self.loop, lambda: proto.nbytes)
        self.assertEqual(1, proto.nbytes)

        os.write(slave, b'2345')
        test_utils.run_until(self.loop, lambda: proto.nbytes >= 5)
        self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
        self.assertEqual(5, proto.nbytes)

        os.close(slave)
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual(
            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))

    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_write_pipe(self):
        """connect_write_pipe sends transport writes out the pipe's read end."""
        rpipe, wpipe = os.pipe()
        pipeobj = io.open(wpipe, 'wb', 1024)

        proto = MyWritePipeProto(loop=self.loop)
        connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
        transport, p = self.loop.run_until_complete(connect)
        self.assertIs(p, proto)
        self.assertIs(transport, proto.transport)
        self.assertEqual('CONNECTED', proto.state)

        transport.write(b'1')

        data = bytearray()
        def reader(data):
            chunk = os.read(rpipe, 1024)
            data += chunk
            return len(data)

        test_utils.run_until(self.loop, lambda: reader(data) >= 1)
        self.assertEqual(b'1', data)

        transport.write(b'2345')
        test_utils.run_until(self.loop, lambda: reader(data) >= 5)
        self.assertEqual(b'12345', data)
        self.assertEqual('CONNECTED', proto.state)

        os.close(rpipe)

        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)

    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    def test_write_pipe_disconnect_on_close(self):
        """Closing the read end disconnects the write-pipe protocol."""
        rsock, wsock = socket.socketpair()
        rsock.setblocking(False)
        pipeobj = io.open(wsock.detach(), 'wb', 1024)

        proto = MyWritePipeProto(loop=self.loop)
        connect = self.loop.connect_write_pipe(lambda: proto, pipeobj)
        transport, p = self.loop.run_until_complete(connect)
        self.assertIs(p, proto)
        self.assertIs(transport, proto.transport)
        self.assertEqual('CONNECTED', proto.state)

        transport.write(b'1')
        data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024))
        self.assertEqual(b'1', data)

        rsock.close()

        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)

    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
    # select, poll and kqueue don't support character devices (PTY) on Mac OS X
    # older than 10.6 (Snow Leopard)
    @support.requires_mac_ver(10, 6)
    def test_write_pty(self):
        """Writing the slave side of a PTY is observable from the master FD."""
        master, slave = os.openpty()
        slave_write_obj = io.open(slave, 'wb', 0)

        proto = MyWritePipeProto(loop=self.loop)
        connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj)
        transport, p = self.loop.run_until_complete(connect)
        self.assertIs(p, proto)
        self.assertIs(transport, proto.transport)
        self.assertEqual('CONNECTED', proto.state)

        transport.write(b'1')

        data = bytearray()
        def reader(data):
            chunk = os.read(master, 1024)
            data += chunk
            return len(data)

        test_utils.run_until(self.loop, lambda: reader(data) >= 1,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'1', data)

        transport.write(b'2345')
        test_utils.run_until(self.loop, lambda: reader(data) >= 5,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'12345', data)
        self.assertEqual('CONNECTED', proto.state)

        os.close(master)

        # extra info is available
        self.assertIsNotNone(proto.transport.get_extra_info('pipe'))

        # close connection
        proto.transport.close()
        self.loop.run_until_complete(proto.done)
        self.assertEqual('CLOSED', proto.state)

    @unittest.skipUnless(sys.platform != 'win32',
                         "Don't support pipes for Windows")
    @unittest.skipUnless(hasattr(os, 'openpty'), 'need os.openpty()')
    # select, poll and kqueue don't support character devices (PTY) on Mac OS X
    # older than 10.6 (Snow Leopard)
    @support.requires_mac_ver(10, 6)
    def test_bidirectional_pty(self):
        """Read and write transports on the same PTY operate independently."""
        master, read_slave = os.openpty()
        write_slave = os.dup(read_slave)
        tty.setraw(read_slave)

        slave_read_obj = io.open(read_slave, 'rb', 0)
        read_proto = MyReadPipeProto(loop=self.loop)
        read_connect = self.loop.connect_read_pipe(lambda: read_proto,
                                                   slave_read_obj)
        read_transport, p = self.loop.run_until_complete(read_connect)
        self.assertIs(p, read_proto)
        self.assertIs(read_transport, read_proto.transport)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual(0, read_proto.nbytes)


        slave_write_obj = io.open(write_slave, 'wb', 0)
        write_proto = MyWritePipeProto(loop=self.loop)
        write_connect = self.loop.connect_write_pipe(lambda: write_proto,
                                                     slave_write_obj)
        write_transport, p = self.loop.run_until_complete(write_connect)
        self.assertIs(p, write_proto)
        self.assertIs(write_transport, write_proto.transport)
        self.assertEqual('CONNECTED', write_proto.state)

        data = bytearray()
        def reader(data):
            chunk = os.read(master, 1024)
            data += chunk
            return len(data)

        write_transport.write(b'1')
        test_utils.run_until(self.loop, lambda: reader(data) >= 1,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'1', data)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual('CONNECTED', write_proto.state)

        os.write(master, b'a')
        test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 1,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual(1, read_proto.nbytes)
        self.assertEqual('CONNECTED', write_proto.state)

        write_transport.write(b'2345')
        test_utils.run_until(self.loop, lambda: reader(data) >= 5,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(b'12345', data)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual('CONNECTED', write_proto.state)

        os.write(master, b'bcde')
        test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 5,
                             timeout=support.SHORT_TIMEOUT)
        self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
        self.assertEqual(5, read_proto.nbytes)
        self.assertEqual('CONNECTED', write_proto.state)

        os.close(master)

        read_transport.close()
        self.loop.run_until_complete(read_proto.done)
        self.assertEqual(
            ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], read_proto.state)

        write_transport.close()
        self.loop.run_until_complete(write_proto.done)
        self.assertEqual('CLOSED', write_proto.state)
    def test_prompt_cancellation(self):
        """Cancelling a pending sock_recv task resolves promptly as cancelled."""
        r, w = socket.socketpair()
        r.setblocking(False)
        f = self.loop.create_task(self.loop.sock_recv(r, 1))
        # `ov` only exists on proactor (Windows overlapped) futures
        ov = getattr(f, 'ov', None)
        if ov is not None:
            self.assertTrue(ov.pending)

        async def main():
            try:
                self.loop.call_soon(f.cancel)
                await f
            except asyncio.CancelledError:
                res = 'cancelled'
            else:
                res = None
            finally:
                self.loop.stop()
            return res

        t = self.loop.create_task(main())
        self.loop.run_forever()

        self.assertEqual(t.result(), 'cancelled')
        self.assertRaises(asyncio.CancelledError, f.result)
        if ov is not None:
            self.assertFalse(ov.pending)
        self.loop._stop_serving(r)

        r.close()
        w.close()

    def test_timeout_rounding(self):
        """Tiny sleep timeouts must not cause excessive _run_once iterations."""
        def _run_once():
            self.loop._run_once_counter += 1
            orig_run_once()

        # monkey-patch _run_once to count loop iterations
        orig_run_once = self.loop._run_once
        self.loop._run_once_counter = 0
        self.loop._run_once = _run_once

        async def wait():
            await asyncio.sleep(1e-2)
            await asyncio.sleep(1e-4)
            await asyncio.sleep(1e-6)
            await asyncio.sleep(1e-8)
            await asyncio.sleep(1e-10)

        self.loop.run_until_complete(wait())
        # The ideal number of call is 12, but on some platforms, the selector
        # may sleep at little bit less than timeout depending on the resolution
        # of the clock used by the kernel. Tolerate a few useless calls on
        # these platforms.
        self.assertLessEqual(self.loop._run_once_counter, 20,
            {'clock_resolution': self.loop._clock_resolution,
             'selector': self.loop._selector.__class__.__name__})

    def test_remove_fds_after_closing(self):
        """remove_reader/remove_writer on a closed loop return False."""
        loop = self.create_event_loop()
        callback = lambda: None
        r, w = socket.socketpair()
        self.addCleanup(r.close)
        self.addCleanup(w.close)
        loop.add_reader(r, callback)
        loop.add_writer(w, callback)
        loop.close()
        self.assertFalse(loop.remove_reader(r))
        self.assertFalse(loop.remove_writer(w))

    def test_add_fds_after_closing(self):
        """add_reader/add_writer on a closed loop raise RuntimeError."""
        loop = self.create_event_loop()
        callback = lambda: None
        r, w = socket.socketpair()
        self.addCleanup(r.close)
        self.addCleanup(w.close)
        loop.close()
        with self.assertRaises(RuntimeError):
            loop.add_reader(r, callback)
        with self.assertRaises(RuntimeError):
            loop.add_writer(w, callback)

    def test_close_running_event_loop(self):
        """Closing the loop from inside a running coroutine raises RuntimeError."""
        async def close_loop(loop):
            self.loop.close()

        coro = close_loop(self.loop)
        with self.assertRaises(RuntimeError):
            self.loop.run_until_complete(coro)

    def test_close(self):
        """Every scheduling API on a closed loop raises RuntimeError."""
        self.loop.close()

        async def test():
            pass

        func = lambda: False
        coro = test()
        self.addCleanup(coro.close)

        # operation blocked when the loop is closed
        with self.assertRaises(RuntimeError):
            self.loop.run_forever()
        with self.assertRaises(RuntimeError):
            fut = self.loop.create_future()
            self.loop.run_until_complete(fut)
        with self.assertRaises(RuntimeError):
            self.loop.call_soon(func)
        with self.assertRaises(RuntimeError):
            self.loop.call_soon_threadsafe(func)
        with self.assertRaises(RuntimeError):
            self.loop.call_later(1.0, func)
        with self.assertRaises(RuntimeError):
            self.loop.call_at(self.loop.time() + .0, func)
        with self.assertRaises(RuntimeError):
            self.loop.create_task(coro)
        with self.assertRaises(RuntimeError):
            self.loop.add_signal_handler(signal.SIGTERM, func)

        # run_in_executor test is tricky: the method is a coroutine,
        # but run_until_complete cannot be called on closed loop.
        # Thus iterate once explicitly.
        with self.assertRaises(RuntimeError):
            it = self.loop.run_in_executor(None, func).__await__()
            next(it)


class SubprocessTestsMixin:
    # Shared subprocess-transport tests; mixed into per-event-loop TestCases.
    # Relies on self.loop being provided by the host TestCase's setUp.

    def check_terminated(self, returncode):
        """Assert the child exited due to SIGTERM (POSIX) or with an int code (win32)."""
        if sys.platform == 'win32':
            self.assertIsInstance(returncode, int)
            # expect 1 but sometimes get 0
        else:
            self.assertEqual(-signal.SIGTERM, returncode)

    def check_killed(self, returncode):
        """Assert the child exited due to SIGKILL (POSIX) or with an int code (win32)."""
        if sys.platform == 'win32':
            self.assertIsInstance(returncode, int)
            # expect 1 but sometimes get 0
        else:
            self.assertEqual(-signal.SIGKILL, returncode)

    @support.requires_subprocess()
    def test_subprocess_exec(self):
        """subprocess_exec launches echo.py and round-trips stdin to stdout."""
        prog = os.path.join(os.path.dirname(__file__), 'echo.py')

        connect = self.loop.subprocess_exec(
            functools.partial(MySubprocessProtocol, self.loop),
            sys.executable, prog)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)
        self.assertEqual('CONNECTED', proto.state)

        stdin = transp.get_pipe_transport(0)
        stdin.write(b'Python The Winner')
        self.loop.run_until_complete(proto.got_data[1].wait())
        with test_utils.disable_logger():
            transp.close()
        self.loop.run_until_complete(proto.completed)
        self.check_killed(proto.returncode)
        self.assertEqual(b'Python The Winner', proto.data[1])

    @support.requires_subprocess()
    def test_subprocess_interactive(self):
        """Two sequential writes are echoed back incrementally."""
        prog = os.path.join(os.path.dirname(__file__), 'echo.py')

        connect = self.loop.subprocess_exec(
            functools.partial(MySubprocessProtocol, self.loop),
            sys.executable, prog)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)
        self.assertEqual('CONNECTED', proto.state)

        stdin = transp.get_pipe_transport(0)
        stdin.write(b'Python ')
        self.loop.run_until_complete(proto.got_data[1].wait())
        proto.got_data[1].clear()
        self.assertEqual(b'Python ', proto.data[1])

        stdin.write(b'The Winner')
        self.loop.run_until_complete(proto.got_data[1].wait())
        self.assertEqual(b'Python The Winner', proto.data[1])

        with test_utils.disable_logger():
            transp.close()
        self.loop.run_until_complete(proto.completed)
        self.check_killed(proto.returncode)

    @support.requires_subprocess()
    def test_subprocess_shell(self):
        """subprocess_shell runs a shell command and captures its stdout."""
        connect = self.loop.subprocess_shell(
            functools.partial(MySubprocessProtocol, self.loop),
            'echo Python')
        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)

        transp.get_pipe_transport(0).close()
        self.loop.run_until_complete(proto.completed)
        self.assertEqual(0, proto.returncode)
        self.assertTrue(all(f.done() for f in proto.disconnects.values()))
        self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python')
        self.assertEqual(proto.data[2], b'')
        transp.close()

    @support.requires_subprocess()
    def test_subprocess_exitcode(self):
        """The child's non-zero exit status is reported as returncode."""
        connect = self.loop.subprocess_shell(
            functools.partial(MySubprocessProtocol, self.loop),
            'exit 7', stdin=None, stdout=None, stderr=None)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.completed)
        self.assertEqual(7, proto.returncode)
        transp.close()

    @support.requires_subprocess()
    def test_subprocess_close_after_finish(self):
        """Closing the transport after the child exited is a harmless no-op."""
        connect = self.loop.subprocess_shell(
            functools.partial(MySubprocessProtocol, self.loop),
            'exit 7', stdin=None, stdout=None, stderr=None)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        # with stdin/stdout/stderr=None there are no pipe transports
        self.assertIsNone(transp.get_pipe_transport(0))
        self.assertIsNone(transp.get_pipe_transport(1))
        self.assertIsNone(transp.get_pipe_transport(2))
        self.loop.run_until_complete(proto.completed)
        self.assertEqual(7, proto.returncode)
        self.assertIsNone(transp.close())

    @support.requires_subprocess()
    def test_subprocess_kill(self):
        """transport.kill() terminates the child with SIGKILL semantics."""
        prog = os.path.join(os.path.dirname(__file__), 'echo.py')

        connect = self.loop.subprocess_exec(
            functools.partial(MySubprocessProtocol, self.loop),
            sys.executable, prog)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)

        transp.kill()
        self.loop.run_until_complete(proto.completed)
        self.check_killed(proto.returncode)
        transp.close()

    @support.requires_subprocess()
    def test_subprocess_terminate(self):
        """transport.terminate() terminates the child with SIGTERM semantics."""
        prog = os.path.join(os.path.dirname(__file__), 'echo.py')

        connect = self.loop.subprocess_exec(
            functools.partial(MySubprocessProtocol, self.loop),
            sys.executable, prog)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)

        transp.terminate()
        self.loop.run_until_complete(proto.completed)
        self.check_terminated(proto.returncode)
        transp.close()

    @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP")
    @support.requires_subprocess()
    def test_subprocess_send_signal(self):
        # bpo-31034: Make sure that we get the default signal handler (killing
        # the process). The parent process may have decided to ignore SIGHUP,
        # and signal handlers are inherited.
        old_handler = signal.signal(signal.SIGHUP, signal.SIG_DFL)
        try:
            prog = os.path.join(os.path.dirname(__file__), 'echo.py')

            connect = self.loop.subprocess_exec(
                functools.partial(MySubprocessProtocol, self.loop),
                sys.executable, prog)


            transp, proto = self.loop.run_until_complete(connect)
            self.assertIsInstance(proto, MySubprocessProtocol)
            self.loop.run_until_complete(proto.connected)

            transp.send_signal(signal.SIGHUP)
            self.loop.run_until_complete(proto.completed)
            self.assertEqual(-signal.SIGHUP, proto.returncode)
            transp.close()
        finally:
            # restore the handler saved above
            signal.signal(signal.SIGHUP, old_handler)

    @support.requires_subprocess()
    def test_subprocess_stderr(self):
        """echo2.py mirrors stdin to stdout (OUT:) and stderr (ERR:)."""
        prog = os.path.join(os.path.dirname(__file__), 'echo2.py')

        connect = self.loop.subprocess_exec(
            functools.partial(MySubprocessProtocol, self.loop),
            sys.executable, prog)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)

        stdin = transp.get_pipe_transport(0)
        stdin.write(b'test')

        self.loop.run_until_complete(proto.completed)

        transp.close()
        self.assertEqual(b'OUT:test', proto.data[1])
        self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2])
        self.assertEqual(0, proto.returncode)

    @support.requires_subprocess()
    def test_subprocess_stderr_redirect_to_stdout(self):
        """With stderr=STDOUT, both streams arrive interleaved on stdout."""
        prog = os.path.join(os.path.dirname(__file__), 'echo2.py')

        connect = self.loop.subprocess_exec(
            functools.partial(MySubprocessProtocol, self.loop),
            sys.executable, prog, stderr=subprocess.STDOUT)


        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)

        stdin = transp.get_pipe_transport(0)
        # stdout pipe exists, stderr pipe does not (redirected)
        self.assertIsNotNone(transp.get_pipe_transport(1))
        self.assertIsNone(transp.get_pipe_transport(2))

        stdin.write(b'test')
        self.loop.run_until_complete(proto.completed)
        self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'),
                        proto.data[1])
        self.assertEqual(b'', proto.data[2])

        transp.close()
        self.assertEqual(0, proto.returncode)

    @support.requires_subprocess()
    def test_subprocess_close_client_stream(self):
        """Writing after the child's stdout pipe closed yields ERR:BrokenPipeError."""
        prog = os.path.join(os.path.dirname(__file__), 'echo3.py')

        connect = self.loop.subprocess_exec(
            functools.partial(MySubprocessProtocol, self.loop),
            sys.executable, prog)

        transp, proto = self.loop.run_until_complete(connect)
        self.assertIsInstance(proto, MySubprocessProtocol)
        self.loop.run_until_complete(proto.connected)

        stdin = transp.get_pipe_transport(0)
        stdout = transp.get_pipe_transport(1)
        stdin.write(b'test')
        self.loop.run_until_complete(proto.got_data[1].wait())
        self.assertEqual(b'OUT:test', proto.data[1])

        stdout.close()
        self.loop.run_until_complete(proto.disconnects[1])
        stdin.write(b'xxx')
        self.loop.run_until_complete(proto.got_data[2].wait())
        if sys.platform != 'win32':
            self.assertEqual(b'ERR:BrokenPipeError', proto.data[2])
        else:
            # After closing the read-end of a pipe, writing to the
            # write-end using os.write() fails with errno==EINVAL and
            # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using
            # WriteFile() we get ERROR_BROKEN_PIPE as expected.)
+ self.assertEqual(b'ERR:OSError', proto.data[2]) + with test_utils.disable_logger(): + transp.close() + self.loop.run_until_complete(proto.completed) + self.check_killed(proto.returncode) + + @support.requires_subprocess() + def test_subprocess_wait_no_same_group(self): + # start the new process in a new session + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None, + start_new_session=True) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + transp.close() + + @support.requires_subprocess() + def test_subprocess_exec_invalid_args(self): + async def connect(**kwds): + await self.loop.subprocess_exec( + asyncio.SubprocessProtocol, + 'pwd', **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=True)) + + @support.requires_subprocess() + def test_subprocess_shell_invalid_args(self): + + async def connect(cmd=None, **kwds): + if not cmd: + cmd = 'pwd' + await self.loop.subprocess_shell( + asyncio.SubprocessProtocol, + cmd, **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(['ls', '-l'])) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=False)) + + +if sys.platform == 'win32': + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop() + + class 
ProactorEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.ProactorEventLoop() + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_remove_fds_after_closing(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") +else: + import selectors + + class UnixEventLoopTestsMixin(EventLoopTestsMixin): + def setUp(self): + super().setUp() + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + asyncio.set_child_watcher(watcher) + + def tearDown(self): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + asyncio.set_child_watcher(None) + super().tearDown() + + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop( + selectors.KqueueSelector()) + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + # Issue #20667: KqueueEventLoopTests.test_read_pty_output() + # hangs on OpenBSD 5.5 + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'test hangs on OpenBSD') + def test_read_pty_output(self): + super().test_read_pty_output() + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + def test_write_pty(self): + super().test_write_pty() 
+ + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.SelectSelector()) + + +def noop(*args, **kwargs): + pass + + +class HandleTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = mock.Mock() + self.loop.get_debug.return_value = True + + def test_handle(self): + def callback(*args): + return args + + args = () + h = asyncio.Handle(callback, args, self.loop) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h.cancelled()) + + h.cancel() + self.assertTrue(h.cancelled()) + + def test_callback_with_exception(self): + def callback(): + raise ValueError() + + self.loop = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + h = asyncio.Handle(callback, (), self.loop) + h._run() + + self.loop.call_exception_handler.assert_called_with({ + 'message': test_utils.MockPattern('Exception in callback.*'), + 'exception': mock.ANY, + 'handle': h, + 'source_traceback': h._source_traceback, + }) + + def test_handle_weakref(self): + wd = weakref.WeakValueDictionary() + h = asyncio.Handle(lambda: None, (), self.loop) + wd['h'] = h # Would fail without __weakref__ slot. + + # TODO: RUSTPYTHON + # AssertionError: '' != '' + # - + # ? 
---- + # + + @unittest.expectedFailure + def test_handle_repr(self): + self.loop.get_debug.return_value = False + + # simple function + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '') + + # decorated function + cb = types.coroutine(noop) + h = asyncio.Handle(cb, (), self.loop) + self.assertEqual(repr(h), + '' + % (filename, lineno)) + + # partial function + cb = functools.partial(noop, 1, 2) + h = asyncio.Handle(cb, (3,), self.loop) + regex = (r'^$' + % (re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + # partial function with keyword args + cb = functools.partial(noop, x=1) + h = asyncio.Handle(cb, (2, 3), self.loop) + regex = (r'^$' + % (re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + # partial method + method = HandleTests.test_handle_repr + cb = functools.partialmethod(method) + filename, lineno = test_utils.get_function_source(method) + h = asyncio.Handle(cb, (), self.loop) + + cb_regex = r'' + cb_regex = fr'functools.partialmethod\({cb_regex}\)\(\)' + regex = fr'^$' + self.assertRegex(repr(h), regex) + + def test_handle_repr_debug(self): + self.loop.get_debug.return_value = True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # double cancellation won't overwrite _repr + h.cancel() + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # partial function + cb = functools.partial(noop, 1, 2) + create_lineno = sys._getframe().f_lineno + 
1 + h = asyncio.Handle(cb, (3,), self.loop) + regex = (r'^$' + % (re.escape(filename), lineno, + re.escape(create_filename), create_lineno)) + self.assertRegex(repr(h), regex) + + # partial function with keyword args + cb = functools.partial(noop, x=1) + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.Handle(cb, (2, 3), self.loop) + regex = (r'^$' + % (re.escape(filename), lineno, + re.escape(create_filename), create_lineno)) + self.assertRegex(repr(h), regex) + + def test_handle_source_traceback(self): + loop = asyncio.get_event_loop_policy().new_event_loop() + loop.set_debug(True) + self.set_event_loop(loop) + + def check_source_traceback(h): + lineno = sys._getframe(1).f_lineno - 1 + self.assertIsInstance(h._source_traceback, list) + self.assertEqual(h._source_traceback[-1][:3], + (__file__, + lineno, + 'test_handle_source_traceback')) + + # call_soon + h = loop.call_soon(noop) + check_source_traceback(h) + + # call_soon_threadsafe + h = loop.call_soon_threadsafe(noop) + check_source_traceback(h) + + # call_later + h = loop.call_later(0, noop) + check_source_traceback(h) + + # call_at + h = loop.call_later(0, noop) + check_source_traceback(h) + + def test_coroutine_like_object_debug_formatting(self): + # Test that asyncio can format coroutines that are instances of + # collections.abc.Coroutine, but lack cr_core or gi_code attributes + # (such as ones compiled with Cython). 
+ + coro = CoroLike() + coro.__name__ = 'AAA' + self.assertTrue(asyncio.iscoroutine(coro)) + self.assertEqual(coroutines._format_coroutine(coro), 'AAA()') + + coro.__qualname__ = 'BBB' + self.assertEqual(coroutines._format_coroutine(coro), 'BBB()') + + coro.cr_running = True + self.assertEqual(coroutines._format_coroutine(coro), 'BBB() running') + + coro.__name__ = coro.__qualname__ = None + self.assertEqual(coroutines._format_coroutine(coro), + '() running') + + coro = CoroLike() + coro.__qualname__ = 'CoroLike' + # Some coroutines might not have '__name__', such as + # built-in async_gen.asend(). + self.assertEqual(coroutines._format_coroutine(coro), 'CoroLike()') + + coro = CoroLike() + coro.__qualname__ = 'AAA' + coro.cr_code = None + self.assertEqual(coroutines._format_coroutine(coro), 'AAA()') + + +class TimerTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.loop = mock.Mock() + + def test_hash(self): + when = time.monotonic() + h = asyncio.TimerHandle(when, lambda: False, (), + mock.Mock()) + self.assertEqual(hash(h), hash(when)) + + def test_when(self): + when = time.monotonic() + h = asyncio.TimerHandle(when, lambda: False, (), + mock.Mock()) + self.assertEqual(when, h.when()) + + def test_timer(self): + def callback(*args): + return args + + args = (1, 2, 3) + when = time.monotonic() + h = asyncio.TimerHandle(when, callback, args, mock.Mock()) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h.cancelled()) + + # cancel + h.cancel() + self.assertTrue(h.cancelled()) + self.assertIsNone(h._callback) + self.assertIsNone(h._args) + + + def test_timer_repr(self): + self.loop.get_debug.return_value = False + + # simple function + h = asyncio.TimerHandle(123, noop, (), self.loop) + src = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' % src) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '') + + def test_timer_repr_debug(self): + self.loop.get_debug.return_value 
= True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.TimerHandle(123, noop, (), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when, callback, (), self.loop) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when + 10.0, callback, (), self.loop) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = asyncio.Handle(callback, (), self.loop) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + with self.assertRaises(TypeError): + h1 < () + with self.assertRaises(TypeError): + h1 > () + with self.assertRaises(TypeError): + h1 <= () + with self.assertRaises(TypeError): + h1 >= () + self.assertFalse(h1 == ()) + self.assertTrue(h1 != ()) + + self.assertTrue(h1 == ALWAYS_EQ) + self.assertFalse(h1 != ALWAYS_EQ) + self.assertTrue(h1 < LARGEST) + self.assertFalse(h1 > LARGEST) + self.assertTrue(h1 <= LARGEST) + self.assertFalse(h1 
>= LARGEST) + self.assertFalse(h1 < SMALLEST) + self.assertTrue(h1 > SMALLEST) + self.assertFalse(h1 <= SMALLEST) + self.assertTrue(h1 >= SMALLEST) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + loop = asyncio.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.is_closed) + self.assertRaises( + NotImplementedError, loop.close) + self.assertRaises( + NotImplementedError, loop.create_task, None) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.set_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.default_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.call_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.get_debug) + self.assertRaises( + NotImplementedError, loop.set_debug, f) + + def 
test_not_implemented_async(self): + + async def inner(): + f = mock.Mock() + loop = asyncio.AbstractEventLoop() + + with self.assertRaises(NotImplementedError): + await loop.run_in_executor(f, f) + with self.assertRaises(NotImplementedError): + await loop.getaddrinfo('localhost', 8080) + with self.assertRaises(NotImplementedError): + await loop.getnameinfo(('localhost', 8080)) + with self.assertRaises(NotImplementedError): + await loop.create_connection(f) + with self.assertRaises(NotImplementedError): + await loop.create_server(f) + with self.assertRaises(NotImplementedError): + await loop.create_datagram_endpoint(f) + with self.assertRaises(NotImplementedError): + await loop.sock_recv(f, 10) + with self.assertRaises(NotImplementedError): + await loop.sock_recv_into(f, 10) + with self.assertRaises(NotImplementedError): + await loop.sock_sendall(f, 10) + with self.assertRaises(NotImplementedError): + await loop.sock_connect(f, f) + with self.assertRaises(NotImplementedError): + await loop.sock_accept(f) + with self.assertRaises(NotImplementedError): + await loop.sock_sendfile(f, f) + with self.assertRaises(NotImplementedError): + await loop.sendfile(f, f) + with self.assertRaises(NotImplementedError): + await loop.connect_read_pipe(f, mock.sentinel.pipe) + with self.assertRaises(NotImplementedError): + await loop.connect_write_pipe(f, mock.sentinel.pipe) + with self.assertRaises(NotImplementedError): + await loop.subprocess_shell(f, mock.sentinel) + with self.assertRaises(NotImplementedError): + await loop.subprocess_exec(f) + + loop = asyncio.new_event_loop() + loop.run_until_complete(inner()) + loop.close() + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = asyncio.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + 
self.assertRaises(NotImplementedError, policy.get_child_watcher) + self.assertRaises(NotImplementedError, policy.set_child_watcher, + object()) + + def test_get_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + self.assertIsNone(policy._local._loop) + with self.assertWarns(DeprecationWarning) as cm: + loop = policy.get_event_loop() + self.assertEqual(cm.filename, __file__) + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + + self.assertIs(policy._local._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_calls_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + with mock.patch.object( + policy, "set_event_loop", + wraps=policy.set_event_loop) as m_set_event_loop: + + with self.assertWarns(DeprecationWarning) as cm: + loop = policy.get_event_loop() + self.addCleanup(loop.close) + self.assertEqual(cm.filename, __file__) + + # policy._local._loop must be set through .set_event_loop() + # (the unix DefaultEventLoopPolicy needs this call to attach + # the child watcher correctly) + m_set_event_loop.assert_called_with(loop) + + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = asyncio.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(RuntimeError, policy.get_event_loop) + + @mock.patch('asyncio.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = asyncio.DefaultEventLoopPolicy() + self.assertRaises(RuntimeError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + old_loop = policy.new_event_loop() + policy.set_event_loop(old_loop) + + self.assertRaises(TypeError, 
policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = asyncio.get_event_loop_policy() + self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + TypeError, asyncio.set_event_loop_policy, object()) + + old_policy = asyncio.get_event_loop_policy() + + policy = asyncio.DefaultEventLoopPolicy() + asyncio.set_event_loop_policy(policy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +class GetEventLoopTestsMixin: + + _get_running_loop_impl = None + _set_running_loop_impl = None + get_running_loop_impl = None + get_event_loop_impl = None + + def setUp(self): + self._get_running_loop_saved = events._get_running_loop + self._set_running_loop_saved = events._set_running_loop + self.get_running_loop_saved = events.get_running_loop + self.get_event_loop_saved = events.get_event_loop + + events._get_running_loop = type(self)._get_running_loop_impl + events._set_running_loop = type(self)._set_running_loop_impl + events.get_running_loop = type(self).get_running_loop_impl + events.get_event_loop = type(self).get_event_loop_impl + + asyncio._get_running_loop = type(self)._get_running_loop_impl + asyncio._set_running_loop = type(self)._set_running_loop_impl + asyncio.get_running_loop = type(self).get_running_loop_impl + asyncio.get_event_loop = type(self).get_event_loop_impl + + super().setUp() + + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + if sys.platform != 'win32': + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + 
asyncio.set_child_watcher(watcher) + + def tearDown(self): + try: + if sys.platform != 'win32': + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + asyncio.set_child_watcher(None) + + super().tearDown() + finally: + self.loop.close() + asyncio.set_event_loop(None) + + events._get_running_loop = self._get_running_loop_saved + events._set_running_loop = self._set_running_loop_saved + events.get_running_loop = self.get_running_loop_saved + events.get_event_loop = self.get_event_loop_saved + + asyncio._get_running_loop = self._get_running_loop_saved + asyncio._set_running_loop = self._set_running_loop_saved + asyncio.get_running_loop = self.get_running_loop_saved + asyncio.get_event_loop = self.get_event_loop_saved + + if sys.platform != 'win32': + + def test_get_event_loop_new_process(self): + # bpo-32126: The multiprocessing module used by + # ProcessPoolExecutor is not functional when the + # multiprocessing.synchronize module cannot be imported. + support.skip_if_broken_multiprocessing_synchronize() + + self.addCleanup(multiprocessing_cleanup_tests) + + async def main(): + if multiprocessing.get_start_method() == 'fork': + # Avoid 'fork' DeprecationWarning. 
+ mp_context = multiprocessing.get_context('forkserver') + else: + mp_context = None + pool = concurrent.futures.ProcessPoolExecutor( + mp_context=mp_context) + result = await self.loop.run_in_executor( + pool, _test_get_event_loop_new_process__sub_proc) + pool.shutdown() + return result + + self.assertEqual( + self.loop.run_until_complete(main()), + 'hello') + + def test_get_running_loop_already_running(self): + async def main(): + running_loop = asyncio.get_running_loop() + with contextlib.closing(asyncio.new_event_loop()) as loop: + try: + loop.run_forever() + except RuntimeError: + pass + else: + self.fail("RuntimeError not raised") + + self.assertIs(asyncio.get_running_loop(), running_loop) + + self.loop.run_until_complete(main()) + + + def test_get_event_loop_returns_running_loop(self): + class TestError(Exception): + pass + + class Policy(asyncio.DefaultEventLoopPolicy): + def get_event_loop(self): + raise TestError + + old_policy = asyncio.get_event_loop_policy() + try: + asyncio.set_event_loop_policy(Policy()) + loop = asyncio.new_event_loop() + + with self.assertRaises(TestError): + asyncio.get_event_loop() + asyncio.set_event_loop(None) + with self.assertRaises(TestError): + asyncio.get_event_loop() + + with self.assertRaisesRegex(RuntimeError, 'no running'): + asyncio.get_running_loop() + self.assertIs(asyncio._get_running_loop(), None) + + async def func(): + self.assertIs(asyncio.get_event_loop(), loop) + self.assertIs(asyncio.get_running_loop(), loop) + self.assertIs(asyncio._get_running_loop(), loop) + + loop.run_until_complete(func()) + + asyncio.set_event_loop(loop) + with self.assertRaises(TestError): + asyncio.get_event_loop() + asyncio.set_event_loop(None) + with self.assertRaises(TestError): + asyncio.get_event_loop() + + finally: + asyncio.set_event_loop_policy(old_policy) + if loop is not None: + loop.close() + + with self.assertRaisesRegex(RuntimeError, 'no running'): + asyncio.get_running_loop() + + 
self.assertIs(asyncio._get_running_loop(), None) + + def test_get_event_loop_returns_running_loop2(self): + old_policy = asyncio.get_event_loop_policy() + try: + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + with self.assertWarns(DeprecationWarning) as cm: + loop2 = asyncio.get_event_loop() + self.addCleanup(loop2.close) + self.assertEqual(cm.filename, __file__) + asyncio.set_event_loop(None) + with self.assertRaisesRegex(RuntimeError, 'no current'): + asyncio.get_event_loop() + + with self.assertRaisesRegex(RuntimeError, 'no running'): + asyncio.get_running_loop() + self.assertIs(asyncio._get_running_loop(), None) + + async def func(): + self.assertIs(asyncio.get_event_loop(), loop) + self.assertIs(asyncio.get_running_loop(), loop) + self.assertIs(asyncio._get_running_loop(), loop) + + loop.run_until_complete(func()) + + asyncio.set_event_loop(loop) + self.assertIs(asyncio.get_event_loop(), loop) + + asyncio.set_event_loop(None) + with self.assertRaisesRegex(RuntimeError, 'no current'): + asyncio.get_event_loop() + + finally: + asyncio.set_event_loop_policy(old_policy) + if loop is not None: + loop.close() + + with self.assertRaisesRegex(RuntimeError, 'no running'): + asyncio.get_running_loop() + + self.assertIs(asyncio._get_running_loop(), None) + + +class TestPyGetEventLoop(GetEventLoopTestsMixin, unittest.TestCase): + + _get_running_loop_impl = events._py__get_running_loop + _set_running_loop_impl = events._py__set_running_loop + get_running_loop_impl = events._py_get_running_loop + get_event_loop_impl = events._py_get_event_loop + + +try: + import _asyncio # NoQA +except ImportError: + pass +else: + + class TestCGetEventLoop(GetEventLoopTestsMixin, unittest.TestCase): + + _get_running_loop_impl = events._c__get_running_loop + _set_running_loop_impl = events._c__set_running_loop + get_running_loop_impl = events._c_get_running_loop + get_event_loop_impl = 
events._c_get_event_loop + + +class TestServer(unittest.TestCase): + + def test_get_loop(self): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + proto = MyProto(loop) + server = loop.run_until_complete(loop.create_server(lambda: proto, '0.0.0.0', 0)) + self.assertEqual(server.get_loop(), loop) + server.close() + loop.run_until_complete(server.wait_closed()) + + +class TestAbstractServer(unittest.TestCase): + + def test_close(self): + with self.assertRaises(NotImplementedError): + events.AbstractServer().close() + + def test_wait_closed(self): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + with self.assertRaises(NotImplementedError): + loop.run_until_complete(events.AbstractServer().wait_closed()) + + def test_get_loop(self): + with self.assertRaises(NotImplementedError): + events.AbstractServer().get_loop() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py new file mode 100644 index 00000000000..162c48a4d67 --- /dev/null +++ b/Lib/test/test_asyncio/test_futures.py @@ -0,0 +1,1161 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import gc +import re +import sys +import threading +import traceback +import unittest +from unittest import mock +from types import GenericAlias +import asyncio +from asyncio import futures +import warnings +from test.test_asyncio import utils as test_utils +from test import support + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +def _fakefunc(f): + return f + + +def first_cb(): + pass + + +def last_cb(): + pass + + +class ReachableCode(Exception): + """Exception to raise to indicate that some code was reached. + + Use this exception if using mocks is not a good alternative. 
+ """ + + +class SimpleEvilEventLoop(asyncio.base_events.BaseEventLoop): + """Base class for UAF and other evil stuff requiring an evil event loop.""" + + def get_debug(self): # to suppress tracebacks + return False + + def __del__(self): + # Automatically close the evil event loop to avoid warnings. + if not self.is_closed() and not self.is_running(): + self.close() + + +class DuckFuture: + # Class that does not inherit from Future but aims to be duck-type + # compatible with it. + + _asyncio_future_blocking = False + __cancelled = False + __result = None + __exception = None + + def cancel(self): + if self.done(): + return False + self.__cancelled = True + return True + + def cancelled(self): + return self.__cancelled + + def done(self): + return (self.__cancelled + or self.__result is not None + or self.__exception is not None) + + def result(self): + self.assertFalse(self.cancelled()) + if self.__exception is not None: + raise self.__exception + return self.__result + + def exception(self): + self.assertFalse(self.cancelled()) + return self.__exception + + def set_result(self, result): + self.assertFalse(self.done()) + self.assertIsNotNone(result) + self.__result = result + + def set_exception(self, exception): + self.assertFalse(self.done()) + self.assertIsNotNone(exception) + self.__exception = exception + + def __iter__(self): + if not self.done(): + self._asyncio_future_blocking = True + yield self + self.assertTrue(self.done()) + return self.result() + + +class DuckTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) + + def test_wrap_future(self): + f = DuckFuture() + g = asyncio.wrap_future(f) + self.assertIs(g, f) + + def test_ensure_future(self): + f = DuckFuture() + g = asyncio.ensure_future(f) + self.assertIs(g, f) + + +class BaseFutureTests: + + def _new_future(self, *args, **kwargs): + return self.cls(*args, **kwargs) + + def setUp(self): + super().setUp() + 
self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) + + def test_generic_alias(self): + future = self.cls[str] + self.assertEqual(future.__args__, (str,)) + self.assertIsInstance(future, GenericAlias) + + def test_isfuture(self): + class MyFuture: + _asyncio_future_blocking = None + + def __init__(self): + self._asyncio_future_blocking = False + + self.assertFalse(asyncio.isfuture(MyFuture)) + self.assertTrue(asyncio.isfuture(MyFuture())) + self.assertFalse(asyncio.isfuture(1)) + + # As `isinstance(Mock(), Future)` returns `False` + self.assertFalse(asyncio.isfuture(mock.Mock())) + + f = self._new_future(loop=self.loop) + self.assertTrue(asyncio.isfuture(f)) + self.assertFalse(asyncio.isfuture(type(f))) + + # As `isinstance(Mock(Future), Future)` returns `True` + self.assertTrue(asyncio.isfuture(mock.Mock(type(f)))) + + f.cancel() + + def test_initial_state(self): + f = self._new_future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_constructor_without_loop(self): + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + self._new_future() + + def test_constructor_use_running_loop(self): + async def test(): + return self._new_future() + f = self.loop.run_until_complete(test()) + self.assertIs(f._loop, self.loop) + self.assertIs(f.get_loop(), self.loop) + + def test_constructor_use_global_loop(self): + # Deprecated in 3.10, undeprecated in 3.12 + asyncio.set_event_loop(self.loop) + self.addCleanup(asyncio.set_event_loop, None) + f = self._new_future() + self.assertIs(f._loop, self.loop) + self.assertIs(f.get_loop(), self.loop) + + def test_constructor_positional(self): + # Make sure Future doesn't accept a positional argument + self.assertRaises(TypeError, self._new_future, 42) + + def test_uninitialized(self): + # Test that C Future doesn't crash when Future.__init__() + # call was skipped. 
+ + fut = self.cls.__new__(self.cls, loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, fut.result) + + fut = self.cls.__new__(self.cls, loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, fut.exception) + + fut = self.cls.__new__(self.cls, loop=self.loop) + with self.assertRaises((RuntimeError, AttributeError)): + fut.set_result(None) + + fut = self.cls.__new__(self.cls, loop=self.loop) + with self.assertRaises((RuntimeError, AttributeError)): + fut.set_exception(Exception) + + fut = self.cls.__new__(self.cls, loop=self.loop) + with self.assertRaises((RuntimeError, AttributeError)): + fut.cancel() + + fut = self.cls.__new__(self.cls, loop=self.loop) + with self.assertRaises((RuntimeError, AttributeError)): + fut.add_done_callback(lambda f: None) + + fut = self.cls.__new__(self.cls, loop=self.loop) + with self.assertRaises((RuntimeError, AttributeError)): + fut.remove_done_callback(lambda f: None) + + fut = self.cls.__new__(self.cls, loop=self.loop) + try: + repr(fut) + except (RuntimeError, AttributeError): + pass + + fut = self.cls.__new__(self.cls, loop=self.loop) + try: + fut.__await__() + except RuntimeError: + pass + + fut = self.cls.__new__(self.cls, loop=self.loop) + try: + iter(fut) + except RuntimeError: + pass + + fut = self.cls.__new__(self.cls, loop=self.loop) + self.assertFalse(fut.cancelled()) + self.assertFalse(fut.done()) + + def test_future_cancel_message_getter(self): + f = self._new_future(loop=self.loop) + self.assertTrue(hasattr(f, '_cancel_message')) + self.assertEqual(f._cancel_message, None) + + f.cancel('my message') + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(f) + self.assertEqual(f._cancel_message, 'my message') + + def test_future_cancel_message_setter(self): + f = self._new_future(loop=self.loop) + f.cancel('my message') + f._cancel_message = 'my new message' + self.assertEqual(f._cancel_message, 'my new message') + + # Also check that the value is used for cancel(). 
+ with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(f) + self.assertEqual(f._cancel_message, 'my new message') + + def test_cancel(self): + f = self._new_future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(asyncio.CancelledError, f.result) + self.assertRaises(asyncio.CancelledError, f.exception) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = self._new_future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = self._new_future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: StopIteration interacts badly with generators and cannot be raised into a Future' + def test_stop_iteration_exception(self, stop_iteration_class=StopIteration): + exc = stop_iteration_class() + f = self._new_future(loop=self.loop) + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + exc = f.exception() + cause = exc.__cause__ + 
self.assertIsInstance(exc, RuntimeError) + self.assertRegex(str(exc), 'StopIteration .* cannot be raised') + self.assertIsInstance(cause, stop_iteration_class) + + def test_stop_iteration_subclass_exception(self): + class MyStopIteration(StopIteration): + pass + + self.test_stop_iteration_exception(MyStopIteration) + + def test_exception_class(self): + f = self._new_future(loop=self.loop) + f.set_exception(RuntimeError) + self.assertIsInstance(f.exception(), RuntimeError) + + def test_yield_from_twice(self): + f = self._new_future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_future_repr(self): + self.loop.set_debug(True) + f_pending_debug = self._new_future(loop=self.loop) + frame = f_pending_debug._source_traceback[-1] + self.assertEqual( + repr(f_pending_debug), + f'<{self.cls.__name__} pending created at {frame[0]}:{frame[1]}>') + f_pending_debug.cancel() + + self.loop.set_debug(False) + f_pending = self._new_future(loop=self.loop) + self.assertEqual(repr(f_pending), f'<{self.cls.__name__} pending>') + f_pending.cancel() + + f_cancelled = self._new_future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), f'<{self.cls.__name__} cancelled>') + + f_result = self._new_future(loop=self.loop) + f_result.set_result(4) + self.assertEqual( + repr(f_result), f'<{self.cls.__name__} finished result=4>') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = self._new_future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual( + repr(f_exception), + f'<{self.cls.__name__} finished exception=RuntimeError()>') + 
self.assertIs(f_exception.exception(), exc) + + def func_repr(func): + filename, lineno = test_utils.get_function_source(func) + text = '%s() at %s:%s' % (func.__qualname__, filename, lineno) + return re.escape(text) + + f_one_callbacks = self._new_future(loop=self.loop) + f_one_callbacks.add_done_callback(_fakefunc) + fake_repr = func_repr(_fakefunc) + self.assertRegex( + repr(f_one_callbacks), + r'<' + self.cls.__name__ + r' pending cb=\[%s\]>' % fake_repr) + f_one_callbacks.cancel() + self.assertEqual(repr(f_one_callbacks), + f'<{self.cls.__name__} cancelled>') + + f_two_callbacks = self._new_future(loop=self.loop) + f_two_callbacks.add_done_callback(first_cb) + f_two_callbacks.add_done_callback(last_cb) + first_repr = func_repr(first_cb) + last_repr = func_repr(last_cb) + self.assertRegex(repr(f_two_callbacks), + r'<' + self.cls.__name__ + r' pending cb=\[%s, %s\]>' + % (first_repr, last_repr)) + + f_many_callbacks = self._new_future(loop=self.loop) + f_many_callbacks.add_done_callback(first_cb) + for i in range(8): + f_many_callbacks.add_done_callback(_fakefunc) + f_many_callbacks.add_done_callback(last_cb) + cb_regex = r'%s, <8 more>, %s' % (first_repr, last_repr) + self.assertRegex( + repr(f_many_callbacks), + r'<' + self.cls.__name__ + r' pending cb=\[%s\]>' % cb_regex) + f_many_callbacks.cancel() + self.assertEqual(repr(f_many_callbacks), + f'<{self.cls.__name__} cancelled>') + + # TODO: RUSTPYTHON + # self.assertEqual(newf_tb.count('raise concurrent.futures.InvalidStateError'), 1) + # AssertionError: 0 != 1 + @unittest.expectedFailure + def test_copy_state(self): + from asyncio.futures import _copy_future_state + + f = self._new_future(loop=self.loop) + f.set_result(10) + + newf = self._new_future(loop=self.loop) + _copy_future_state(f, newf) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = self._new_future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = 
self._new_future(loop=self.loop) + _copy_future_state(f_exception, newf_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = self._new_future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = self._new_future(loop=self.loop) + _copy_future_state(f_cancelled, newf_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + try: + raise concurrent.futures.InvalidStateError + except BaseException as e: + f_exc = e + + f_conexc = self._new_future(loop=self.loop) + f_conexc.set_exception(f_exc) + + newf_conexc = self._new_future(loop=self.loop) + _copy_future_state(f_conexc, newf_conexc) + self.assertTrue(newf_conexc.done()) + try: + newf_conexc.result() + except BaseException as e: + newf_exc = e # assertRaises context manager drops the traceback + newf_tb = ''.join(traceback.format_tb(newf_exc.__traceback__)) + self.assertEqual(newf_tb.count('raise concurrent.futures.InvalidStateError'), 1) + + def test_iter(self): + fut = self._new_future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + with self.assertRaisesRegex(RuntimeError, "await wasn't used"): + test() + fut.cancel() + + def test_log_traceback(self): + fut = self._new_future(loop=self.loop) + with self.assertRaisesRegex(ValueError, 'can only be set to False'): + fut._log_traceback = True + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_abandoned(self, m_log): + fut = self._new_future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_not_called_after_cancel(self, m_log): + fut = self._new_future(loop=self.loop) + fut.set_exception(Exception()) + fut.cancel() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_unretrieved(self, m_log): + fut = self._new_future(loop=self.loop) + fut.set_result(42) + del fut + 
self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_retrieved(self, m_log): + fut = self._new_future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = self._new_future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + support.gc_collect() + self.assertTrue(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_retrieved(self, m_log): + fut = self._new_future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = self._new_future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertTrue(asyncio.isfuture(f2)) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + ex.shutdown(wait=True) + + def test_wrap_future_future(self): + f1 = self._new_future(loop=self.loop) + f2 = asyncio.wrap_future(f1) + self.assertIs(f1, f2) + + def test_wrap_future_without_loop(self): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + asyncio.wrap_future(f1) + ex.shutdown(wait=True) + + def test_wrap_future_use_running_loop(self): + def run(arg): + 
return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + async def test(): + return asyncio.wrap_future(f1) + f2 = self.loop.run_until_complete(test()) + self.assertIs(self.loop, f2._loop) + ex.shutdown(wait=True) + + def test_wrap_future_use_global_loop(self): + # Deprecated in 3.10, undeprecated in 3.12 + asyncio.set_event_loop(self.loop) + self.addCleanup(asyncio.set_event_loop, None) + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1) + self.assertIs(self.loop, f2._loop) + ex.shutdown(wait=True) + + def test_wrap_future_cancel(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(f1.cancelled()) + self.assertTrue(f2.cancelled()) + + def test_wrap_future_cancel2(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f1.set_result(42) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertFalse(f1.cancelled()) + self.assertEqual(f1.result(), 42) + self.assertTrue(f2.cancelled()) + + def test_future_source_traceback(self): + self.loop.set_debug(True) + + future = self._new_future(loop=self.loop) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(future._source_traceback, list) + self.assertEqual(future._source_traceback[-2][:3], + (__file__, + lineno, + 'test_future_source_traceback')) + + @mock.patch('asyncio.base_events.logger') + def check_future_exception_never_retrieved(self, debug, m_log): + self.loop.set_debug(debug) + + def memory_error(): + try: + raise MemoryError() + except BaseException as exc: + return exc + exc = memory_error() + + future = self._new_future(loop=self.loop) + future.set_exception(exc) + future = None + test_utils.run_briefly(self.loop) + support.gc_collect() + + regex = f'^{self.cls.__name__} exception was never 
retrieved\n' + exc_info = (type(exc), exc, exc.__traceback__) + m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info) + + message = m_log.error.call_args[0][0] + self.assertRegex(message, re.compile(regex, re.DOTALL)) + + def test_future_exception_never_retrieved(self): + self.check_future_exception_never_retrieved(False) + + def test_future_exception_never_retrieved_debug(self): + self.check_future_exception_never_retrieved(True) + + def test_set_result_unless_cancelled(self): + fut = self._new_future(loop=self.loop) + fut.cancel() + futures._set_result_unless_cancelled(fut, 2) + self.assertTrue(fut.cancelled()) + + def test_future_stop_iteration_args(self): + fut = self._new_future(loop=self.loop) + fut.set_result((1, 2)) + fi = fut.__iter__() + result = None + try: + fi.send(None) + except StopIteration as ex: + result = ex.args[0] + else: + self.fail('StopIteration was expected') + self.assertEqual(result, (1, 2)) + + # TODO: RUSTPYTHON + # DeprecationWarning not triggered + @unittest.expectedFailure + def test_future_iter_throw(self): + fut = self._new_future(loop=self.loop) + fi = iter(fut) + with self.assertWarns(DeprecationWarning): + self.assertRaises(Exception, fi.throw, Exception, Exception("zebra"), None) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + self.assertRaises(TypeError, fi.throw, + Exception, Exception("elephant"), 32) + self.assertRaises(TypeError, fi.throw, + Exception("elephant"), Exception("elephant")) + # https://github.com/python/cpython/issues/101326 + self.assertRaises(ValueError, fi.throw, ValueError, None, None) + self.assertRaises(TypeError, fi.throw, list) + + def test_future_del_collect(self): + class Evil: + def __del__(self): + gc.collect() + + for i in range(100): + fut = self._new_future(loop=self.loop) + fut.set_result(Evil()) + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + def test_future_cancelled_result_refcycles(self): + f = 
self._new_future(loop=self.loop) + f.cancel() + exc = None + try: + f.result() + except asyncio.CancelledError as e: + exc = e + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + def test_future_cancelled_exception_refcycles(self): + f = self._new_future(loop=self.loop) + f.cancel() + exc = None + try: + f.exception() + except asyncio.CancelledError as e: + exc = e + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + + +@unittest.skipUnless(hasattr(futures, '_CFuture'), + 'requires the C _asyncio module') +class CFutureTests(BaseFutureTests, test_utils.TestCase): + try: + cls = futures._CFuture + except AttributeError: + cls = None + + def test_future_del_segfault(self): + fut = self._new_future(loop=self.loop) + with self.assertRaises(AttributeError): + del fut._asyncio_future_blocking + with self.assertRaises(AttributeError): + del fut._log_traceback + + def test_future_iter_get_referents_segfault(self): + # See https://github.com/python/cpython/issues/122695 + import _asyncio + it = iter(self._new_future(loop=self.loop)) + del it + evil = gc.get_referents(_asyncio) + gc.collect() + + def test_callbacks_copy(self): + # See https://github.com/python/cpython/issues/125789 + # In C implementation, the `_callbacks` attribute + # always returns a new list to avoid mutations of internal state + + fut = self._new_future(loop=self.loop) + f1 = lambda _: 1 + f2 = lambda _: 2 + fut.add_done_callback(f1) + fut.add_done_callback(f2) + callbacks = fut._callbacks + self.assertIsNot(callbacks, fut._callbacks) + fut.remove_done_callback(f1) + callbacks = fut._callbacks + self.assertIsNot(callbacks, fut._callbacks) + fut.remove_done_callback(f2) + self.assertIsNone(fut._callbacks) + + +@unittest.skipUnless(hasattr(futures, '_CFuture'), + 'requires the C _asyncio module') +class CSubFutureTests(BaseFutureTests, test_utils.TestCase): + try: + class 
CSubFuture(futures._CFuture): + pass + + cls = CSubFuture + except AttributeError: + cls = None + + +class PyFutureTests(BaseFutureTests, test_utils.TestCase): + cls = futures._PyFuture + + +class BaseFutureDoneCallbackTests(): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + raise NotImplementedError + + def test_callbacks_remove_first_callback(self): + bag = [] + f = self._new_future() + + cb1 = self._make_callback(bag, 42) + cb2 = self._make_callback(bag, 17) + cb3 = self._make_callback(bag, 100) + + f.add_done_callback(cb1) + f.add_done_callback(cb2) + f.add_done_callback(cb3) + + f.remove_done_callback(cb1) + f.remove_done_callback(cb1) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [17, 100]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_remove_first_and_second_callback(self): + bag = [] + f = self._new_future() + + cb1 = self._make_callback(bag, 42) + cb2 = self._make_callback(bag, 17) + cb3 = self._make_callback(bag, 100) + + f.add_done_callback(cb1) + f.add_done_callback(cb2) + f.add_done_callback(cb3) + + f.remove_done_callback(cb1) + f.remove_done_callback(cb2) + f.remove_done_callback(cb1) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_remove_third_callback(self): + bag = [] + f = self._new_future() + + cb1 = self._make_callback(bag, 42) + cb2 = self._make_callback(bag, 17) + cb3 = self._make_callback(bag, 100) + + f.add_done_callback(cb1) + f.add_done_callback(cb2) + f.add_done_callback(cb3) + + f.remove_done_callback(cb3) + f.remove_done_callback(cb3) + + self.assertEqual(bag, []) + 
f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. 
+ self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + def test_remove_done_callbacks_list_mutation(self): + # see http://bugs.python.org/issue28963 for details + + fut = self._new_future() + fut.add_done_callback(str) + + for _ in range(63): + fut.add_done_callback(id) + + class evil: + def __eq__(self, other): + fut.remove_done_callback(id) + return False + + fut.remove_done_callback(evil()) + + def test_remove_done_callbacks_list_clear(self): + # see https://github.com/python/cpython/issues/97592 for details + + fut = self._new_future() + fut.add_done_callback(str) + + for _ in range(63): + fut.add_done_callback(id) + + class evil: + def __eq__(self, other): + fut.remove_done_callback(other) + + fut.remove_done_callback(evil()) + + def test_schedule_callbacks_list_mutation_1(self): + # see http://bugs.python.org/issue28963 for details + + def mut(f): + f.remove_done_callback(str) + + fut = self._new_future() + fut.add_done_callback(mut) + fut.add_done_callback(str) + fut.add_done_callback(str) + fut.set_result(1) + test_utils.run_briefly(self.loop) + + def test_schedule_callbacks_list_mutation_2(self): + # see http://bugs.python.org/issue30828 for details + + fut = self._new_future() + fut.add_done_callback(str) + + for _ in range(63): + fut.add_done_callback(id) + + max_extra_cbs = 100 + extra_cbs = 0 + + class evil: + def __eq__(self, other): + nonlocal extra_cbs + extra_cbs += 1 + if extra_cbs < max_extra_cbs: + fut.add_done_callback(id) + return False + + fut.remove_done_callback(evil()) + + def test_evil_call_soon_list_mutation(self): + # see: https://github.com/python/cpython/issues/125969 + called_on_fut_callback0 = False + + pad = lambda: ... 
+ + def evil_call_soon(*args, **kwargs): + nonlocal called_on_fut_callback0 + if called_on_fut_callback0: + # Called when handling fut->fut_callbacks[0] + # and mutates the length fut->fut_callbacks. + fut.remove_done_callback(int) + fut.remove_done_callback(pad) + else: + called_on_fut_callback0 = True + + fake_event_loop = SimpleEvilEventLoop() + fake_event_loop.call_soon = evil_call_soon + + with mock.patch.object(self, 'loop', fake_event_loop): + fut = self._new_future() + self.assertIs(fut.get_loop(), fake_event_loop) + + fut.add_done_callback(str) # sets fut->fut_callback0 + fut.add_done_callback(int) # sets fut->fut_callbacks[0] + fut.add_done_callback(pad) # sets fut->fut_callbacks[1] + fut.add_done_callback(pad) # sets fut->fut_callbacks[2] + fut.set_result("boom") + + # When there are no more callbacks, the Python implementation + # returns an empty list but the C implementation returns None. + self.assertIn(fut._callbacks, (None, [])) + + def test_use_after_free_on_fut_callback_0_with_evil__eq__(self): + # Special thanks to Nico-Posada for the original PoC. + # See https://github.com/python/cpython/issues/125966. 
+ + fut = self._new_future() + + class cb_pad: + def __eq__(self, other): + return True + + class evil(cb_pad): + def __eq__(self, other): + fut.remove_done_callback(None) + return NotImplemented + + fut.add_done_callback(cb_pad()) + fut.remove_done_callback(evil()) + + def test_use_after_free_on_fut_callback_0_with_evil__getattribute__(self): + # see: https://github.com/python/cpython/issues/125984 + + class EvilEventLoop(SimpleEvilEventLoop): + def call_soon(self, *args, **kwargs): + super().call_soon(*args, **kwargs) + raise ReachableCode + + def __getattribute__(self, name): + nonlocal fut_callback_0 + if name == 'call_soon': + fut.remove_done_callback(fut_callback_0) + del fut_callback_0 + return object.__getattribute__(self, name) + + evil_loop = EvilEventLoop() + with mock.patch.object(self, 'loop', evil_loop): + fut = self._new_future() + self.assertIs(fut.get_loop(), evil_loop) + + fut_callback_0 = lambda: ... + fut.add_done_callback(fut_callback_0) + self.assertRaises(ReachableCode, fut.set_result, "boom") + + def test_use_after_free_on_fut_context_0_with_evil__getattribute__(self): + # see: https://github.com/python/cpython/issues/125984 + + class EvilEventLoop(SimpleEvilEventLoop): + def call_soon(self, *args, **kwargs): + super().call_soon(*args, **kwargs) + raise ReachableCode + + def __getattribute__(self, name): + if name == 'call_soon': + # resets the future's event loop + fut.__init__(loop=SimpleEvilEventLoop()) + return object.__getattribute__(self, name) + + evil_loop = EvilEventLoop() + with mock.patch.object(self, 'loop', evil_loop): + fut = self._new_future() + self.assertIs(fut.get_loop(), evil_loop) + + fut_callback_0 = mock.Mock() + fut_context_0 = mock.Mock() + fut.add_done_callback(fut_callback_0, context=fut_context_0) + del fut_context_0 + del fut_callback_0 + self.assertRaises(ReachableCode, fut.set_result, "boom") + + +@unittest.skipUnless(hasattr(futures, '_CFuture'), + 'requires the C _asyncio module') +class 
CFutureDoneCallbackTests(BaseFutureDoneCallbackTests, + test_utils.TestCase): + + def _new_future(self): + return futures._CFuture(loop=self.loop) + + +@unittest.skipUnless(hasattr(futures, '_CFuture'), + 'requires the C _asyncio module') +class CSubFutureDoneCallbackTests(BaseFutureDoneCallbackTests, + test_utils.TestCase): + + def _new_future(self): + class CSubFuture(futures._CFuture): + pass + return CSubFuture(loop=self.loop) + + +class PyFutureDoneCallbackTests(BaseFutureDoneCallbackTests, + test_utils.TestCase): + + def _new_future(self): + return futures._PyFuture(loop=self.loop) + + +class BaseFutureInheritanceTests: + + def _get_future_cls(self): + raise NotImplementedError + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) + + def test_inherit_without_calling_super_init(self): + # See https://bugs.python.org/issue38785 for the context + cls = self._get_future_cls() + + class MyFut(cls): + def __init__(self, *args, **kwargs): + # don't call super().__init__() + pass + + fut = MyFut(loop=self.loop) + with self.assertRaisesRegex( + RuntimeError, + "Future object is not initialized." 
+ ): + fut.get_loop() + + +class PyFutureInheritanceTests(BaseFutureInheritanceTests, + test_utils.TestCase): + def _get_future_cls(self): + return futures._PyFuture + + +@unittest.skipUnless(hasattr(futures, '_CFuture'), + 'requires the C _asyncio module') +class CFutureInheritanceTests(BaseFutureInheritanceTests, + test_utils.TestCase): + def _get_future_cls(self): + return futures._CFuture + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_futures2.py b/Lib/test/test_asyncio/test_futures2.py new file mode 100644 index 00000000000..bdc3ca2eca9 --- /dev/null +++ b/Lib/test/test_asyncio/test_futures2.py @@ -0,0 +1,100 @@ +# IsolatedAsyncioTestCase based tests +import asyncio +import contextvars +import traceback +import unittest +from asyncio import tasks + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class FutureTests: + + # TODO: RUSTPYTHON + # self.assertEqual(tb.count("await future"), 1) + # AssertionError: 0 != 1 + @unittest.expectedFailure + async def test_future_traceback(self): + + async def raise_exc(): + raise TypeError(42) + + future = self.cls(raise_exc()) + + for _ in range(5): + try: + await future + except TypeError as e: + tb = ''.join(traceback.format_tb(e.__traceback__)) + self.assertEqual(tb.count("await future"), 1) + else: + self.fail('TypeError was not raised') + + async def test_task_exc_handler_correct_context(self): + # see https://github.com/python/cpython/issues/96704 + name = contextvars.ContextVar('name', default='foo') + exc_handler_called = False + + def exc_handler(*args): + self.assertEqual(name.get(), 'bar') + nonlocal exc_handler_called + exc_handler_called = True + + async def task(): + name.set('bar') + 1/0 + + loop = asyncio.get_running_loop() + loop.set_exception_handler(exc_handler) + self.cls(task()) + await asyncio.sleep(0) + self.assertTrue(exc_handler_called) + + async def test_handle_exc_handler_correct_context(self): + # see 
https://github.com/python/cpython/issues/96704 + name = contextvars.ContextVar('name', default='foo') + exc_handler_called = False + + def exc_handler(*args): + self.assertEqual(name.get(), 'bar') + nonlocal exc_handler_called + exc_handler_called = True + + def callback(): + name.set('bar') + 1/0 + + loop = asyncio.get_running_loop() + loop.set_exception_handler(exc_handler) + loop.call_soon(callback) + await asyncio.sleep(0) + self.assertTrue(exc_handler_called) + +# TODO: RUSTPYTHON +# @unittest.skipUnless(hasattr(tasks, '_CTask'), +# 'requires the C _asyncio module') +# class CFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase): +# cls = tasks._CTask + +class PyFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase): + cls = tasks._PyTask + +class FutureReprTests(unittest.IsolatedAsyncioTestCase): + + async def test_recursive_repr_for_pending_tasks(self): + # The call crashes if the guard for recursive call + # in base_futures:_future_repr_info is absent + # See Also: https://bugs.python.org/issue42183 + + async def func(): + return asyncio.all_tasks() + + # The repr() call should not raise RecursionError at first. + waiter = await asyncio.wait_for(asyncio.Task(func()),timeout=10) + self.assertIn('...', repr(waiter)) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py new file mode 100644 index 00000000000..2240a09cf8f --- /dev/null +++ b/Lib/test/test_asyncio/test_locks.py @@ -0,0 +1,1838 @@ +"""Tests for locks.py""" + +import unittest +from unittest import mock +import re + +import asyncio +import collections + +STR_RGX_REPR = ( + r'^<(?P.*?) object at (?P
.*?)' + r'\[(?P' + r'(set|unset|locked|unlocked|filling|draining|resetting|broken)' + r'(, value:\d)?' + r'(, waiters:\d+)?' + r'(, waiters:\d+\/\d+)?' # barrier + r')\]>\Z' +) +RGX_REPR = re.compile(STR_RGX_REPR) + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class LockTests(unittest.IsolatedAsyncioTestCase): + + async def test_repr(self): + lock = asyncio.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + await lock.acquire() + self.assertTrue(repr(lock).endswith('[locked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + async def test_lock(self): + lock = asyncio.Lock() + + with self.assertRaisesRegex( + TypeError, + "object Lock can't be used in 'await' expression" + ): + await lock + + self.assertFalse(lock.locked()) + + async def test_lock_doesnt_accept_loop_parameter(self): + primitives_cls = [ + asyncio.Lock, + asyncio.Condition, + asyncio.Event, + asyncio.Semaphore, + asyncio.BoundedSemaphore, + ] + + loop = asyncio.get_running_loop() + + for cls in primitives_cls: + with self.assertRaisesRegex( + TypeError, + rf"{cls.__name__}\.__init__\(\) got an unexpected " + rf"keyword argument 'loop'" + ): + cls(loop=loop) + + async def test_lock_by_with_statement(self): + primitives = [ + asyncio.Lock(), + asyncio.Condition(), + asyncio.Semaphore(), + asyncio.BoundedSemaphore(), + ] + + for lock in primitives: + await asyncio.sleep(0.01) + self.assertFalse(lock.locked()) + with self.assertRaisesRegex( + TypeError, + r"object \w+ can't be used in 'await' expression" + ): + with await lock: + pass + self.assertFalse(lock.locked()) + + async def test_acquire(self): + lock = asyncio.Lock() + result = [] + + self.assertTrue(await lock.acquire()) + + async def c1(result): + if await lock.acquire(): + result.append(1) + return True + + async def c2(result): + if await lock.acquire(): + result.append(2) + return True + + async def c3(result): + if await lock.acquire(): + result.append(3) 
+ return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + lock.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + await asyncio.sleep(0) + self.assertEqual([1], result) + + t3 = asyncio.create_task(c3(result)) + + lock.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + + lock.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_acquire_cancel(self): + lock = asyncio.Lock() + self.assertTrue(await lock.acquire()) + + task = asyncio.create_task(lock.acquire()) + asyncio.get_running_loop().call_soon(task.cancel) + with self.assertRaises(asyncio.CancelledError): + await task + self.assertFalse(lock._waiters) + + async def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in acquire() + # - C is blocked in acquire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. 
+ lock = asyncio.Lock() + + async def lockit(name, blocker): + await lock.acquire() + try: + if blocker is not None: + await blocker + finally: + lock.release() + + fa = asyncio.get_running_loop().create_future() + ta = asyncio.create_task(lockit('A', fa)) + await asyncio.sleep(0) + self.assertTrue(lock.locked()) + tb = asyncio.create_task(lockit('B', None)) + await asyncio.sleep(0) + self.assertEqual(len(lock._waiters), 1) + tc = asyncio.create_task(lockit('C', None)) + await asyncio.sleep(0) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + await asyncio.sleep(0) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + await tc + + async def test_cancel_release_race(self): + # Issue 32734 + # Acquire 4 locks, cancel second, release first + # and 2 locks are taken at once. + loop = asyncio.get_running_loop() + lock = asyncio.Lock() + lock_count = 0 + call_count = 0 + + async def lockit(): + nonlocal lock_count + nonlocal call_count + call_count += 1 + await lock.acquire() + lock_count += 1 + + def trigger(): + t1.cancel() + lock.release() + + await lock.acquire() + + t1 = asyncio.create_task(lockit()) + t2 = asyncio.create_task(lockit()) + t3 = asyncio.create_task(lockit()) + + # Start scheduled tasks + await asyncio.sleep(0) + + loop.call_soon(trigger) + with self.assertRaises(asyncio.CancelledError): + # Wait for cancellation + await t1 + + # Make sure only one lock was taken + self.assertEqual(lock_count, 1) + # While 3 calls were made to lockit() + self.assertEqual(call_count, 3) + self.assertTrue(t1.cancelled() and t2.done()) + + # Cleanup the task that is stuck on acquire. 
+ t3.cancel() + await asyncio.sleep(0) + self.assertTrue(t3.cancelled()) + + async def test_finished_waiter_cancelled(self): + lock = asyncio.Lock() + + await lock.acquire() + self.assertTrue(lock.locked()) + + tb = asyncio.create_task(lock.acquire()) + await asyncio.sleep(0) + self.assertEqual(len(lock._waiters), 1) + + # Create a second waiter, wake up the first, and cancel it. + # Without the fix, the second was not woken up. + tc = asyncio.create_task(lock.acquire()) + tb.cancel() + lock.release() + await asyncio.sleep(0) + + self.assertTrue(lock.locked()) + self.assertTrue(tb.cancelled()) + + # Cleanup + await tc + + async def test_release_not_acquired(self): + lock = asyncio.Lock() + + self.assertRaises(RuntimeError, lock.release) + + async def test_release_no_waiters(self): + lock = asyncio.Lock() + await lock.acquire() + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + async def test_context_manager(self): + lock = asyncio.Lock() + self.assertFalse(lock.locked()) + + async with lock: + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + +class EventTests(unittest.IsolatedAsyncioTestCase): + + def test_repr(self): + ev = asyncio.Event() + self.assertTrue(repr(ev).endswith('[unset]>')) + match = RGX_REPR.match(repr(ev)) + self.assertEqual(match.group('extras'), 'unset') + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + self.assertTrue(RGX_REPR.match(repr(ev))) + + ev._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(ev)) + self.assertTrue(RGX_REPR.match(repr(ev))) + + async def test_wait(self): + ev = asyncio.Event() + self.assertFalse(ev.is_set()) + + result = [] + + async def c1(result): + if await ev.wait(): + result.append(1) + + async def c2(result): + if await ev.wait(): + result.append(2) + + async def c3(result): + if await ev.wait(): + result.append(3) + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + 
self.assertEqual([], result) + + t3 = asyncio.create_task(c3(result)) + + ev.set() + await asyncio.sleep(0) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + async def test_wait_on_set(self): + ev = asyncio.Event() + ev.set() + + res = await ev.wait() + self.assertTrue(res) + + async def test_wait_cancel(self): + ev = asyncio.Event() + + wait = asyncio.create_task(ev.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(ev._waiters) + + async def test_clear(self): + ev = asyncio.Event() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + async def test_clear_with_waiters(self): + ev = asyncio.Event() + result = [] + + async def c1(result): + if await ev.wait(): + result.append(1) + return True + + t = asyncio.create_task(c1(result)) + await asyncio.sleep(0) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.IsolatedAsyncioTestCase): + + async def test_wait(self): + cond = asyncio.Condition() + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + 
await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_wait_cancel(self): + cond = asyncio.Condition() + await cond.acquire() + + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = asyncio.Condition() + + await cond.acquire() + self.assertTrue(cond.locked()) + + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) + + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) + + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass + + self.assertTrue(cond.locked()) + + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False + + cond = asyncio.Condition() + + async def wait_on_cond(): + nonlocal waited + async with cond: + 
waited = True # Make sure this area was reached + await cond.wait() + + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting + + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = asyncio.Condition() + with self.assertRaises(RuntimeError): + await cond.wait() + + async def test_wait_for(self): + cond = asyncio.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() + return True + + t = asyncio.create_task(c1(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) + + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + async def test_wait_for_unacquired(self): + cond = asyncio.Condition() + + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) + + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) + + async def test_notify(self): + cond = asyncio.Condition() + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + cond.release() + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + cond.release() + return True + + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + cond.release() + return True + + t1 = 
asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify(1) + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + await cond.acquire() + cond.notify(1) + cond.notify(2048) + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = asyncio.Condition() + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + cond.release() + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + cond.release() + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify_all() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = asyncio.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = asyncio.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + async def test_repr(self): + cond = asyncio.Condition() + self.assertTrue('unlocked' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + await cond.acquire() + self.assertTrue('locked' in repr(cond)) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(cond)) + 
self.assertTrue(RGX_REPR.match(repr(cond))) + + async def test_context_manager(self): + cond = asyncio.Condition() + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) + + async def test_explicit_lock(self): + async def f(lock=None, cond=None): + if lock is None: + lock = asyncio.Lock() + if cond is None: + cond = asyncio.Condition(lock) + self.assertIs(cond._lock, lock) + self.assertFalse(lock.locked()) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(lock.locked()) + self.assertTrue(cond.locked()) + self.assertFalse(lock.locked()) + self.assertFalse(cond.locked()) + async with lock: + self.assertTrue(lock.locked()) + self.assertTrue(cond.locked()) + self.assertFalse(lock.locked()) + self.assertFalse(cond.locked()) + + # All should work in the same way. + await f() + await f(asyncio.Lock()) + lock = asyncio.Lock() + await f(lock, asyncio.Condition(lock)) + + async def test_ambiguous_loops(self): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + async def wrong_loop_in_lock(): + with self.assertRaises(TypeError): + asyncio.Lock(loop=loop) # actively disallowed since 3.10 + lock = asyncio.Lock() + lock._loop = loop # use private API for testing + async with lock: + # acquired immediately via the fast-path + # without interaction with any event loop. + cond = asyncio.Condition(lock) + # cond.acquire() will trigger waiting on the lock + # and it will discover the event loop mismatch. + with self.assertRaisesRegex( + RuntimeError, + "is bound to a different event loop", + ): + await cond.acquire() + + async def wrong_loop_in_cond(): + # Same analogy here with the condition's loop. 
+ lock = asyncio.Lock() + async with lock: + with self.assertRaises(TypeError): + asyncio.Condition(lock, loop=loop) + cond = asyncio.Condition(lock) + cond._loop = loop + with self.assertRaisesRegex( + RuntimeError, + "is bound to a different event loop", + ): + await cond.wait() + + await wrong_loop_in_lock() + await wrong_loop_in_cond() + + async def test_timeout_in_block(self): + condition = asyncio.Condition() + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = asyncio.Condition() + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + # TODO: RUSTPYTHON + # AssertionError: Tuples differ: () != ('foo',) + @unittest.expectedFailure + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. 
+ wake = False + raised = None + cond = asyncio.Condition() + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + # TODO: RUSTPYTHON + # AssertionError: 1 != 0 + @unittest.expectedFailure + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = asyncio.Condition() + state = 0 + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. 
+ # because it raced with a cancel() + try: + async with asyncio.timeout(0.01): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + # TODO: RUSTPYTHON + # AssertionError: 1 != 0 + @unittest.expectedFailure + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. + condition = asyncio.Condition() + state = 0 + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. 
+ # because it raced with a cancel() + try: + async with asyncio.timeout(0.01): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + +class SemaphoreTests(unittest.IsolatedAsyncioTestCase): + + def test_initial_value_zero(self): + sem = asyncio.Semaphore(0) + self.assertTrue(sem.locked()) + + async def test_repr(self): + sem = asyncio.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked, value:1]>')) + self.assertTrue(RGX_REPR.match(repr(sem))) + + await sem.acquire() + self.assertTrue(repr(sem).endswith('[locked]>')) + self.assertTrue('waiters' not in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + if sem._waiters is None: + sem._waiters = collections.deque() + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + async def test_semaphore(self): + sem = asyncio.Semaphore() + self.assertEqual(1, sem._value) + + with self.assertRaisesRegex( + TypeError, + "object Semaphore can't be used in 'await' expression", + ): + await sem + + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, asyncio.Semaphore, -1) + + async def test_acquire(self): + sem = asyncio.Semaphore(3) + result = [] + + self.assertTrue(await sem.acquire()) + self.assertTrue(await sem.acquire()) + self.assertFalse(sem.locked()) + + async def c1(result): + await sem.acquire() + result.append(1) + return True + + async def c2(result): + await sem.acquire() + result.append(2) + return True + + async def c3(result): + await sem.acquire() + result.append(3) + return True + + async def c4(result): + await sem.acquire() + result.append(4) + return True + + t1 = asyncio.create_task(c1(result)) + t2 
= asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = asyncio.create_task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(0, sem._value) + + await asyncio.sleep(0) + self.assertEqual(0, sem._value) + self.assertEqual(3, len(result)) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + race_tasks = [t2, t3, t4] + done_tasks = [t for t in race_tasks if t.done() and t.result()] + self.assertEqual(2, len(done_tasks)) + + # cleanup locked semaphore + sem.release() + await asyncio.gather(*race_tasks) + + async def test_acquire_cancel(self): + sem = asyncio.Semaphore() + await sem.acquire() + + acquire = asyncio.create_task(sem.acquire()) + asyncio.get_running_loop().call_soon(acquire.cancel) + with self.assertRaises(asyncio.CancelledError): + await acquire + self.assertTrue((not sem._waiters) or + all(waiter.done() for waiter in sem._waiters)) + + async def test_acquire_cancel_before_awoken(self): + sem = asyncio.Semaphore(value=0) + + t1 = asyncio.create_task(sem.acquire()) + t2 = asyncio.create_task(sem.acquire()) + t3 = asyncio.create_task(sem.acquire()) + t4 = asyncio.create_task(sem.acquire()) + + await asyncio.sleep(0) + + t1.cancel() + t2.cancel() + sem.release() + + await asyncio.sleep(0) + await asyncio.sleep(0) + num_done = sum(t.done() for t in [t3, t4]) + self.assertEqual(num_done, 1) + self.assertTrue(t3.done()) + self.assertFalse(t4.done()) + + t3.cancel() + t4.cancel() + await asyncio.sleep(0) + + async def test_acquire_hang(self): + sem = asyncio.Semaphore(value=0) + + t1 = asyncio.create_task(sem.acquire()) + t2 = asyncio.create_task(sem.acquire()) + await asyncio.sleep(0) + + t1.cancel() + sem.release() + await asyncio.sleep(0) + 
await asyncio.sleep(0) + self.assertTrue(sem.locked()) + self.assertTrue(t2.done()) + + async def test_acquire_no_hang(self): + + sem = asyncio.Semaphore(1) + + async def c1(): + async with sem: + await asyncio.sleep(0) + t2.cancel() + + async def c2(): + async with sem: + self.assertFalse(True) + + t1 = asyncio.create_task(c1()) + t2 = asyncio.create_task(c2()) + + r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True) + self.assertTrue(r1 is None) + self.assertTrue(isinstance(r2, asyncio.CancelledError)) + + await asyncio.wait_for(sem.acquire(), timeout=1.0) + + def test_release_not_acquired(self): + sem = asyncio.BoundedSemaphore() + + self.assertRaises(ValueError, sem.release) + + async def test_release_no_waiters(self): + sem = asyncio.Semaphore() + await sem.acquire() + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + async def test_acquire_fifo_order(self): + sem = asyncio.Semaphore(1) + result = [] + + async def coro(tag): + await sem.acquire() + result.append(f'{tag}_1') + await asyncio.sleep(0.01) + sem.release() + + await sem.acquire() + result.append(f'{tag}_2') + await asyncio.sleep(0.01) + sem.release() + + async with asyncio.TaskGroup() as tg: + tg.create_task(coro('c1')) + tg.create_task(coro('c2')) + tg.create_task(coro('c3')) + + self.assertEqual( + ['c1_1', 'c2_1', 'c3_1', 'c1_2', 'c2_2', 'c3_2'], + result + ) + + async def test_acquire_fifo_order_2(self): + sem = asyncio.Semaphore(1) + result = [] + + async def c1(result): + await sem.acquire() + result.append(1) + return True + + async def c2(result): + await sem.acquire() + result.append(2) + sem.release() + await sem.acquire() + result.append(4) + return True + + async def c3(result): + await sem.acquire() + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + + sem.release() + sem.release() + + tasks = [t1, t2, t3] + await 
asyncio.gather(*tasks) + self.assertEqual([1, 2, 3, 4], result) + + async def test_acquire_fifo_order_3(self): + sem = asyncio.Semaphore(0) + result = [] + + async def c1(result): + await sem.acquire() + result.append(1) + return True + + async def c2(result): + await sem.acquire() + result.append(2) + return True + + async def c3(result): + await sem.acquire() + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + + t1.cancel() + + await asyncio.sleep(0) + + sem.release() + sem.release() + + tasks = [t1, t2, t3] + await asyncio.gather(*tasks, return_exceptions=True) + self.assertEqual([2, 3], result) + + + # TODO: RUSTPYTHON + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: An asyncio.Future, a coroutine or an awaitable is required + async def test_acquire_fifo_order_4(self): + # Test that a successfule `acquire()` will wake up multiple Tasks + # that were waiting in the Semaphore queue due to FIFO rules. + sem = asyncio.Semaphore(0) + result = [] + count = 0 + + async def c1(result): + # First task immediatlly waits for semaphore. It will be awoken by c2. + self.assertEqual(sem._value, 0) + await sem.acquire() + # We should have woken up all waiting tasks now. + self.assertEqual(sem._value, 0) + # Create a fourth task. It should run after c3, not c2. + nonlocal t4 + t4 = asyncio.create_task(c4(result)) + result.append(1) + return True + + async def c2(result): + # The second task begins by releasing semaphore three times, + # for c1, c2, and c3. + sem.release() + sem.release() + sem.release() + self.assertEqual(sem._value, 2) + # It is locked, because c1 hasn't woken up yet. 
+ self.assertTrue(sem.locked()) + await sem.acquire() + result.append(2) + return True + + async def c3(result): + await sem.acquire() + self.assertTrue(sem.locked()) + result.append(3) + return True + + async def c4(result): + result.append(4) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + t4 = None + + await asyncio.sleep(0) + # Three tasks are in the queue, the first hasn't woken up yet. + self.assertEqual(sem._value, 2) + self.assertEqual(len(sem._waiters), 3) + await asyncio.sleep(0) + + tasks = [t1, t2, t3, t4] + await asyncio.gather(*tasks) + self.assertEqual([1, 2, 3, 4], result) + +class BarrierTests(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + await super().asyncSetUp() + self.N = 5 + + def make_tasks(self, n, coro): + tasks = [asyncio.create_task(coro()) for _ in range(n)] + return tasks + + async def gather_tasks(self, n, coro): + tasks = self.make_tasks(n, coro) + res = await asyncio.gather(*tasks) + return res, tasks + + async def test_barrier(self): + barrier = asyncio.Barrier(self.N) + self.assertIn("filling", repr(barrier)) + with self.assertRaisesRegex( + TypeError, + "object Barrier can't be used in 'await' expression", + ): + await barrier + + self.assertIn("filling", repr(barrier)) + + async def test_repr(self): + barrier = asyncio.Barrier(self.N) + + self.assertTrue(RGX_REPR.match(repr(barrier))) + self.assertIn("filling", repr(barrier)) + + waiters = [] + async def wait(barrier): + await barrier.wait() + + incr = 2 + for i in range(incr): + waiters.append(asyncio.create_task(wait(barrier))) + await asyncio.sleep(0) + + self.assertTrue(RGX_REPR.match(repr(barrier))) + self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier)) + self.assertIn("filling", repr(barrier)) + + # create missing waiters + for i in range(barrier.parties - barrier.n_waiting): + waiters.append(asyncio.create_task(wait(barrier))) + await asyncio.sleep(0) + + 
self.assertTrue(RGX_REPR.match(repr(barrier))) + self.assertIn("draining", repr(barrier)) + + # add a part of waiters + for i in range(incr): + waiters.append(asyncio.create_task(wait(barrier))) + await asyncio.sleep(0) + # and reset + await barrier.reset() + + self.assertTrue(RGX_REPR.match(repr(barrier))) + self.assertIn("resetting", repr(barrier)) + + # add a part of waiters again + for i in range(incr): + waiters.append(asyncio.create_task(wait(barrier))) + await asyncio.sleep(0) + # and abort + await barrier.abort() + + self.assertTrue(RGX_REPR.match(repr(barrier))) + self.assertIn("broken", repr(barrier)) + self.assertTrue(barrier.broken) + + # suppress unhandled exceptions + await asyncio.gather(*waiters, return_exceptions=True) + + async def test_barrier_parties(self): + self.assertRaises(ValueError, lambda: asyncio.Barrier(0)) + self.assertRaises(ValueError, lambda: asyncio.Barrier(-4)) + + self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier) + + async def test_context_manager(self): + self.N = 3 + barrier = asyncio.Barrier(self.N) + results = [] + + async def coro(): + async with barrier as i: + results.append(i) + + await self.gather_tasks(self.N, coro) + + self.assertListEqual(sorted(results), list(range(self.N))) + self.assertEqual(barrier.n_waiting, 0) + self.assertFalse(barrier.broken) + + async def test_filling_one_task(self): + barrier = asyncio.Barrier(1) + + async def f(): + async with barrier as i: + return True + + ret = await f() + + self.assertTrue(ret) + self.assertEqual(barrier.n_waiting, 0) + self.assertFalse(barrier.broken) + + async def test_filling_one_task_twice(self): + barrier = asyncio.Barrier(1) + + t1 = asyncio.create_task(barrier.wait()) + await asyncio.sleep(0) + self.assertEqual(barrier.n_waiting, 0) + + t2 = asyncio.create_task(barrier.wait()) + await asyncio.sleep(0) + + self.assertEqual(t1.result(), t2.result()) + self.assertEqual(t1.done(), t2.done()) + + self.assertEqual(barrier.n_waiting, 0) + 
self.assertFalse(barrier.broken) + + async def test_filling_task_by_task(self): + self.N = 3 + barrier = asyncio.Barrier(self.N) + + t1 = asyncio.create_task(barrier.wait()) + await asyncio.sleep(0) + self.assertEqual(barrier.n_waiting, 1) + self.assertIn("filling", repr(barrier)) + + t2 = asyncio.create_task(barrier.wait()) + await asyncio.sleep(0) + self.assertEqual(barrier.n_waiting, 2) + self.assertIn("filling", repr(barrier)) + + t3 = asyncio.create_task(barrier.wait()) + await asyncio.sleep(0) + + await asyncio.wait([t1, t2, t3]) + + self.assertEqual(barrier.n_waiting, 0) + self.assertFalse(barrier.broken) + + async def test_filling_tasks_wait_twice(self): + barrier = asyncio.Barrier(self.N) + results = [] + + async def coro(): + async with barrier: + results.append(True) + + async with barrier: + results.append(False) + + await self.gather_tasks(self.N, coro) + + self.assertEqual(len(results), self.N*2) + self.assertEqual(results.count(True), self.N) + self.assertEqual(results.count(False), self.N) + + self.assertEqual(barrier.n_waiting, 0) + self.assertFalse(barrier.broken) + + async def test_filling_tasks_check_return_value(self): + barrier = asyncio.Barrier(self.N) + results1 = [] + results2 = [] + + async def coro(): + async with barrier: + results1.append(True) + + async with barrier as i: + results2.append(True) + return i + + res, _ = await self.gather_tasks(self.N, coro) + + self.assertEqual(len(results1), self.N) + self.assertTrue(all(results1)) + self.assertEqual(len(results2), self.N) + self.assertTrue(all(results2)) + self.assertListEqual(sorted(res), list(range(self.N))) + + self.assertEqual(barrier.n_waiting, 0) + self.assertFalse(barrier.broken) + + async def test_draining_state(self): + barrier = asyncio.Barrier(self.N) + results = [] + + async def coro(): + async with barrier: + # barrier state change to filling for the last task release + results.append("draining" in repr(barrier)) + + await self.gather_tasks(self.N, coro) + + 
self.assertEqual(len(results), self.N) + self.assertEqual(results[-1], False) + self.assertTrue(all(results[:self.N-1])) + + self.assertEqual(barrier.n_waiting, 0) + self.assertFalse(barrier.broken) + + async def test_blocking_tasks_while_draining(self): + rewait = 2 + barrier = asyncio.Barrier(self.N) + barrier_nowaiting = asyncio.Barrier(self.N - rewait) + results = [] + rewait_n = rewait + counter = 0 + + async def coro(): + nonlocal rewait_n + + # first time waiting + await barrier.wait() + + # after wainting once for all tasks + if rewait_n > 0: + rewait_n -= 1 + # wait again only for rewait tasks + await barrier.wait() + else: + # wait for end of draining state` + await barrier_nowaiting.wait() + # wait for other waiting tasks + await barrier.wait() + + # a success means that barrier_nowaiting + # was waited for exactly N-rewait=3 times + await self.gather_tasks(self.N, coro) + + async def test_filling_tasks_cancel_one(self): + self.N = 3 + barrier = asyncio.Barrier(self.N) + results = [] + + async def coro(): + await barrier.wait() + results.append(True) + + t1 = asyncio.create_task(coro()) + await asyncio.sleep(0) + self.assertEqual(barrier.n_waiting, 1) + + t2 = asyncio.create_task(coro()) + await asyncio.sleep(0) + self.assertEqual(barrier.n_waiting, 2) + + t1.cancel() + await asyncio.sleep(0) + self.assertEqual(barrier.n_waiting, 1) + with self.assertRaises(asyncio.CancelledError): + await t1 + self.assertTrue(t1.cancelled()) + + t3 = asyncio.create_task(coro()) + await asyncio.sleep(0) + self.assertEqual(barrier.n_waiting, 2) + + t4 = asyncio.create_task(coro()) + await asyncio.gather(t2, t3, t4) + + self.assertEqual(len(results), self.N) + self.assertTrue(all(results)) + + self.assertEqual(barrier.n_waiting, 0) + self.assertFalse(barrier.broken) + + async def test_reset_barrier(self): + barrier = asyncio.Barrier(1) + + asyncio.create_task(barrier.reset()) + await asyncio.sleep(0) + + self.assertEqual(barrier.n_waiting, 0) + 
self.assertFalse(barrier.broken) + + async def test_reset_barrier_while_tasks_waiting(self): + barrier = asyncio.Barrier(self.N) + results = [] + + async def coro(): + try: + await barrier.wait() + except asyncio.BrokenBarrierError: + results.append(True) + + async def coro_reset(): + await barrier.reset() + + # N-1 tasks waiting on barrier with N parties + tasks = self.make_tasks(self.N-1, coro) + await asyncio.sleep(0) + + # reset the barrier + asyncio.create_task(coro_reset()) + await asyncio.gather(*tasks) + + self.assertEqual(len(results), self.N-1) + self.assertTrue(all(results)) + self.assertEqual(barrier.n_waiting, 0) + self.assertNotIn("resetting", repr(barrier)) + self.assertFalse(barrier.broken) + + async def test_reset_barrier_when_tasks_half_draining(self): + barrier = asyncio.Barrier(self.N) + results1 = [] + rest_of_tasks = self.N//2 + + async def coro(): + try: + await barrier.wait() + except asyncio.BrokenBarrierError: + # catch here waiting tasks + results1.append(True) + else: + # here drained task outside the barrier + if rest_of_tasks == barrier._count: + # tasks outside the barrier + await barrier.reset() + + await self.gather_tasks(self.N, coro) + + self.assertEqual(results1, [True]*rest_of_tasks) + self.assertEqual(barrier.n_waiting, 0) + self.assertNotIn("resetting", repr(barrier)) + self.assertFalse(barrier.broken) + + async def test_reset_barrier_when_tasks_half_draining_half_blocking(self): + barrier = asyncio.Barrier(self.N) + results1 = [] + results2 = [] + blocking_tasks = self.N//2 + count = 0 + + async def coro(): + nonlocal count + try: + await barrier.wait() + except asyncio.BrokenBarrierError: + # here catch still waiting tasks + results1.append(True) + + # so now waiting again to reach nb_parties + await barrier.wait() + else: + count += 1 + if count > blocking_tasks: + # reset now: raise asyncio.BrokenBarrierError for waiting tasks + await barrier.reset() + + # so now waiting again to reach nb_parties + await barrier.wait() + 
else: + try: + await barrier.wait() + except asyncio.BrokenBarrierError: + # here no catch - blocked tasks go to wait + results2.append(True) + + await self.gather_tasks(self.N, coro) + + self.assertEqual(results1, [True]*blocking_tasks) + self.assertEqual(results2, []) + self.assertEqual(barrier.n_waiting, 0) + self.assertNotIn("resetting", repr(barrier)) + self.assertFalse(barrier.broken) + + async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self): + barrier = asyncio.Barrier(self.N) + results1 = [] + results2 = [] + + async def coro1(): + try: + await barrier.wait() + except asyncio.BrokenBarrierError: + results1.append(True) + finally: + await barrier.wait() + results2.append(True) + + async def coro2(): + async with barrier: + results2.append(True) + + tasks = self.make_tasks(self.N-1, coro1) + + # reset barrier, N-1 waiting tasks raise an BrokenBarrierError + asyncio.create_task(barrier.reset()) + await asyncio.sleep(0) + + # complete waiting tasks in the `finally` + asyncio.create_task(coro2()) + + await asyncio.gather(*tasks) + + self.assertFalse(barrier.broken) + self.assertEqual(len(results1), self.N-1) + self.assertTrue(all(results1)) + self.assertEqual(len(results2), self.N) + self.assertTrue(all(results2)) + + self.assertEqual(barrier.n_waiting, 0) + + + async def test_reset_barrier_while_tasks_draining(self): + barrier = asyncio.Barrier(self.N) + results1 = [] + results2 = [] + results3 = [] + count = 0 + + async def coro(): + nonlocal count + + i = await barrier.wait() + count += 1 + if count == self.N: + # last task exited from barrier + await barrier.reset() + + # wait here to reach the `parties` + await barrier.wait() + else: + try: + # second waiting + await barrier.wait() + + # N-1 tasks here + results1.append(True) + except Exception as e: + # never goes here + results2.append(True) + + # Now, pass the barrier again + # last wait, must be completed + k = await barrier.wait() + results3.append(True) + + await 
self.gather_tasks(self.N, coro) + + self.assertFalse(barrier.broken) + self.assertTrue(all(results1)) + self.assertEqual(len(results1), self.N-1) + self.assertEqual(len(results2), 0) + self.assertEqual(len(results3), self.N) + self.assertTrue(all(results3)) + + self.assertEqual(barrier.n_waiting, 0) + + async def test_abort_barrier(self): + barrier = asyncio.Barrier(1) + + asyncio.create_task(barrier.abort()) + await asyncio.sleep(0) + + self.assertEqual(barrier.n_waiting, 0) + self.assertTrue(barrier.broken) + + async def test_abort_barrier_when_tasks_half_draining_half_blocking(self): + barrier = asyncio.Barrier(self.N) + results1 = [] + results2 = [] + blocking_tasks = self.N//2 + count = 0 + + async def coro(): + nonlocal count + try: + await barrier.wait() + except asyncio.BrokenBarrierError: + # here catch tasks waiting to drain + results1.append(True) + else: + count += 1 + if count > blocking_tasks: + # abort now: raise asyncio.BrokenBarrierError for all tasks + await barrier.abort() + else: + try: + await barrier.wait() + except asyncio.BrokenBarrierError: + # here catch blocked tasks (already drained) + results2.append(True) + + await self.gather_tasks(self.N, coro) + + self.assertTrue(barrier.broken) + self.assertEqual(results1, [True]*blocking_tasks) + self.assertEqual(results2, [True]*(self.N-blocking_tasks-1)) + self.assertEqual(barrier.n_waiting, 0) + self.assertNotIn("resetting", repr(barrier)) + + async def test_abort_barrier_when_exception(self): + # test from threading.Barrier: see `lock_tests.test_reset` + barrier = asyncio.Barrier(self.N) + results1 = [] + results2 = [] + + async def coro(): + try: + async with barrier as i : + if i == self.N//2: + raise RuntimeError + async with barrier: + results1.append(True) + except asyncio.BrokenBarrierError: + results2.append(True) + except RuntimeError: + await barrier.abort() + + await self.gather_tasks(self.N, coro) + + self.assertTrue(barrier.broken) + self.assertEqual(len(results1), 0) + 
self.assertEqual(len(results2), self.N-1) + self.assertTrue(all(results2)) + self.assertEqual(barrier.n_waiting, 0) + + async def test_abort_barrier_when_exception_then_resetting(self): + # test from threading.Barrier: see `lock_tests.test_abort_and_reset`` + barrier1 = asyncio.Barrier(self.N) + barrier2 = asyncio.Barrier(self.N) + results1 = [] + results2 = [] + results3 = [] + + async def coro(): + try: + i = await barrier1.wait() + if i == self.N//2: + raise RuntimeError + await barrier1.wait() + results1.append(True) + except asyncio.BrokenBarrierError: + results2.append(True) + except RuntimeError: + await barrier1.abort() + + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + i = await barrier2.wait() + if i == self.N//2: + await barrier1.reset() + await barrier2.wait() + await barrier1.wait() + results3.append(True) + + await self.gather_tasks(self.N, coro) + + self.assertFalse(barrier1.broken) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(all(results2)) + self.assertEqual(len(results3), self.N) + self.assertTrue(all(results3)) + + self.assertEqual(barrier1.n_waiting, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_pep492.py b/Lib/test/test_asyncio/test_pep492.py new file mode 100644 index 00000000000..dc25a46985e --- /dev/null +++ b/Lib/test/test_asyncio/test_pep492.py @@ -0,0 +1,212 @@ +"""Tests support for new syntax introduced by PEP 492.""" + +import sys +import types +import unittest + +from unittest import mock + +import asyncio +from test.test_asyncio import utils as test_utils + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +# Test that asyncio.iscoroutine() uses collections.abc.Coroutine +class FakeCoro: + def send(self, value): + pass + + def throw(self, typ, val=None, tb=None): + pass + + def close(self): + pass + + 
def __await__(self): + yield + + +class BaseTest(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.BaseEventLoop() + self.loop._process_events = mock.Mock() + self.loop._selector = mock.Mock() + self.loop._selector.select.return_value = () + self.set_event_loop(self.loop) + + +class LockTests(BaseTest): + + def test_context_manager_async_with(self): + primitives = [ + asyncio.Lock(), + asyncio.Condition(), + asyncio.Semaphore(), + asyncio.BoundedSemaphore(), + ] + + async def test(lock): + await asyncio.sleep(0.01) + self.assertFalse(lock.locked()) + async with lock as _lock: + self.assertIs(_lock, None) + self.assertTrue(lock.locked()) + await asyncio.sleep(0.01) + self.assertTrue(lock.locked()) + self.assertFalse(lock.locked()) + + for primitive in primitives: + self.loop.run_until_complete(test(primitive)) + self.assertFalse(primitive.locked()) + + def test_context_manager_with_await(self): + primitives = [ + asyncio.Lock(), + asyncio.Condition(), + asyncio.Semaphore(), + asyncio.BoundedSemaphore(), + ] + + async def test(lock): + await asyncio.sleep(0.01) + self.assertFalse(lock.locked()) + with self.assertRaisesRegex( + TypeError, + "can't be used in 'await' expression" + ): + with await lock: + pass + + for primitive in primitives: + self.loop.run_until_complete(test(primitive)) + self.assertFalse(primitive.locked()) + + +class StreamReaderTests(BaseTest): + + def test_readline(self): + DATA = b'line1\nline2\nline3' + + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(DATA) + stream.feed_eof() + + async def reader(): + data = [] + async for line in stream: + data.append(line) + return data + + data = self.loop.run_until_complete(reader()) + self.assertEqual(data, [b'line1\n', b'line2\n', b'line3']) + + +class CoroutineTests(BaseTest): + + def test_iscoroutine(self): + async def foo(): pass + + f = foo() + try: + self.assertTrue(asyncio.iscoroutine(f)) + finally: + f.close() # silence warning + + 
self.assertTrue(asyncio.iscoroutine(FakeCoro())) + + def test_iscoroutine_generator(self): + def foo(): yield + + self.assertFalse(asyncio.iscoroutine(foo())) + + + def test_iscoroutinefunction(self): + async def foo(): pass + self.assertTrue(asyncio.iscoroutinefunction(foo)) + + def test_async_def_coroutines(self): + async def bar(): + return 'spam' + async def foo(): + return await bar() + + # production mode + data = self.loop.run_until_complete(foo()) + self.assertEqual(data, 'spam') + + # debug mode + self.loop.set_debug(True) + data = self.loop.run_until_complete(foo()) + self.assertEqual(data, 'spam') + + def test_debug_mode_manages_coroutine_origin_tracking(self): + async def start(): + self.assertTrue(sys.get_coroutine_origin_tracking_depth() > 0) + + self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0) + self.loop.set_debug(True) + self.loop.run_until_complete(start()) + self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0) + + def test_types_coroutine(self): + def gen(): + yield from () + return 'spam' + + @types.coroutine + def func(): + return gen() + + async def coro(): + wrapper = func() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + return await wrapper + + data = self.loop.run_until_complete(coro()) + self.assertEqual(data, 'spam') + + def test_task_print_stack(self): + T = None + + async def foo(): + f = T.get_stack(limit=1) + try: + self.assertEqual(f[0].f_code.co_name, 'foo') + finally: + f = None + + async def runner(): + nonlocal T + T = asyncio.ensure_future(foo(), loop=self.loop) + await T + + self.loop.run_until_complete(runner()) + + def test_double_await(self): + async def afunc(): + await asyncio.sleep(0.1) + + async def runner(): + coro = afunc() + t = self.loop.create_task(coro) + try: + await asyncio.sleep(0) + await coro + finally: + t.cancel() + + self.loop.set_debug(True) + with self.assertRaises( + RuntimeError, + msg='coroutine is being awaited already'): + + self.loop.run_until_complete(runner()) 
+ + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py new file mode 100644 index 00000000000..65cba3b9c27 --- /dev/null +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -0,0 +1,1100 @@ +"""Tests for proactor_events.py""" + +import io +import socket +import unittest +import sys +from unittest import mock + +import asyncio +from asyncio.proactor_events import BaseProactorEventLoop +from asyncio.proactor_events import _ProactorSocketTransport +from asyncio.proactor_events import _ProactorWritePipeTransport +from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio.proactor_events import _ProactorDatagramTransport +from test.support import os_helper +from test.support import socket_helper +from test.test_asyncio import utils as test_utils + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +def close_transport(transport): + # Don't call transport.close() because the event loop and the IOCP proactor + # are mocked + if transport._sock is None: + return + transport._sock.close() + transport._sock = None + + +class ProactorSocketTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) + self.proactor = mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + self.buffer_size = 65536 + + def socket_transport(self, waiter=None): + transport = _ProactorSocketTransport(self.loop, self.sock, + self.protocol, waiter=waiter) + self.addCleanup(close_transport, transport) + return transport + + def test_ctor(self): + fut = self.loop.create_future() + tr = self.socket_transport(waiter=fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv_into.assert_called_with(self.sock, 
bytearray(self.buffer_size)) + + def test_loop_reading(self): + tr = self.socket_transport() + tr._loop_reading() + self.loop._proactor.recv_into.assert_called_with(self.sock, bytearray(self.buffer_size)) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + buf = b'data' + res = self.loop.create_future() + res.set_result(len(buf)) + + tr = self.socket_transport() + tr._read_fut = res + tr._data[:len(buf)] = buf + tr._loop_reading(res) + called_buf = bytearray(self.buffer_size) + called_buf[:len(buf)] = buf + self.loop._proactor.recv_into.assert_called_with(self.sock, called_buf) + self.protocol.data_received.assert_called_with(buf) + # assert_called_with maps bytearray and bytes to the same thing so check manually + # regression test for https://github.com/python/cpython/issues/99941 + self.assertIsInstance(self.protocol.data_received.call_args.args[0], bytes) + + @unittest.skipIf(sys.flags.optimize, "Assertions are disabled in optimized mode") + def test_loop_reading_no_data(self): + res = self.loop.create_future() + res.set_result(0) + + tr = self.socket_transport() + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv_into.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv_into.side_effect = ConnectionAbortedError() + + tr = self.socket_transport() + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv_into.side_effect = ConnectionAbortedError() + + tr = self.socket_transport() + tr._closing = True + tr._fatal_error = mock.Mock() + tr._loop_reading() + 
self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv_into.side_effect = ConnectionAbortedError() + tr = self.socket_transport() + tr._closing = False + tr._fatal_error = mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv_into.side_effect = ConnectionResetError() + + tr = self.socket_transport() + tr._closing = False + tr._fatal_error = mock.Mock() + tr._force_close = mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv_into.side_effect = (OSError()) + + tr = self.socket_transport() + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_write(self): + tr = self.socket_transport() + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, None) + tr._loop_writing.assert_called_with(data=b'data') + + def test_write_no_data(self): + tr = self.socket_transport() + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = self.socket_transport() + tr._write_fut = mock.Mock() + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, b'data') + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = self.socket_transport() + tr._buffer = bytearray(b'data') + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @mock.patch('asyncio.proactor_events.logger') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = self.socket_transport() + tr._fatal_error = mock.Mock() + tr._buffer = [b'da', 
b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, None) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = self.loop.create_future() + fut.set_result(b'data') + + tr = self.socket_transport() + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = self.loop.create_future() + fut.set_result(1) + + tr = self.socket_transport() + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = self.socket_transport() + tr._force_close = mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = self.socket_transport() + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr.is_closing()) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = self.socket_transport() + tr._write_fut = mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = self.socket_transport() + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_invalid_sockobj(self): + tr = self.socket_transport() + self.sock.fileno.return_value = -1 + tr.close() + test_utils.run_briefly(self.loop) + 
self.protocol.connection_lost.assert_called_with(None) + self.assertFalse(self.sock.shutdown.called) + + @mock.patch('asyncio.base_events.logger') + def test_fatal_error(self, m_logging): + tr = self.socket_transport() + tr._force_close = mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.error.called) + + def test_force_close(self): + tr = self.socket_transport() + tr._buffer = [b'data'] + read_fut = tr._read_fut = mock.Mock() + write_fut = tr._write_fut = mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_loop_writing_force_close(self): + exc_handler = mock.Mock() + self.loop.set_exception_handler(exc_handler) + fut = self.loop.create_future() + fut.set_result(1) + self.proactor.send.return_value = fut + + tr = self.socket_transport() + tr.write(b'data') + tr._force_close(None) + test_utils.run_briefly(self.loop) + exc_handler.assert_not_called() + + def test_force_close_idempotent(self): + tr = self.socket_transport() + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + # See https://github.com/python/cpython/issues/89237 + # `protocol.connection_lost` should be called even if + # the transport was closed forcefully otherwise + # the resources held by protocol will never be freed + # and waiters will never be notified leading to hang. 
+ self.assertTrue(self.protocol.connection_lost.called) + + def test_force_close_protocol_connection_lost_once(self): + tr = self.socket_transport() + self.assertFalse(self.protocol.connection_lost.called) + tr._closing = True + # Calling _force_close twice should not call + # protocol.connection_lost twice + tr._force_close(None) + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertEqual(1, self.protocol.connection_lost.call_count) + + def test_close_protocol_connection_lost_once(self): + tr = self.socket_transport() + self.assertFalse(self.protocol.connection_lost.called) + # Calling close twice should not call + # protocol.connection_lost twice + tr.close() + tr.close() + test_utils.run_briefly(self.loop) + self.assertEqual(1, self.protocol.connection_lost.call_count) + + def test_fatal_error_2(self): + tr = self.socket_transport() + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + + def test_call_connection_lost(self): + tr = self.socket_transport() + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = self.socket_transport() + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = self.socket_transport() + f = self.loop.create_future() + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = 
_ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr.is_closing()) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = self.loop.create_future() + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr.is_closing()) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + close_transport(tr) + + def test_pause_resume_reading(self): + tr = self.socket_transport() + index = 0 + msgs = [b'data1', b'data2', b'data3', b'data4', b'data5', b''] + reversed_msgs = list(reversed(msgs)) + + def recv_into(sock, data): + f = self.loop.create_future() + msg = reversed_msgs.pop() + + result = f.result + def monkey(): + data[:len(msg)] = msg + return result() + f.result = monkey + + f.set_result(len(msg)) + return f + + self.loop._proactor.recv_into.side_effect = recv_into + self.loop._run_once() + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + + for msg in msgs[:2]: + self.loop._run_once() + self.protocol.data_received.assert_called_with(bytearray(msg)) + + tr.pause_reading() + tr.pause_reading() + self.assertTrue(tr._paused) + self.assertFalse(tr.is_reading()) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(bytearray(msgs[1])) + + tr.resume_reading() + tr.resume_reading() + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + + 
for msg in msgs[2:4]: + self.loop._run_once() + self.protocol.data_received.assert_called_with(bytearray(msg)) + + tr.pause_reading() + tr.resume_reading() + self.loop.call_exception_handler = mock.Mock() + self.loop._run_once() + self.loop.call_exception_handler.assert_not_called() + self.protocol.data_received.assert_called_with(bytearray(msgs[4])) + tr.close() + + self.assertFalse(tr.is_reading()) + + def test_pause_reading_connection_made(self): + tr = self.socket_transport() + self.protocol.connection_made.side_effect = lambda _: tr.pause_reading() + test_utils.run_briefly(self.loop) + self.assertFalse(tr.is_reading()) + self.loop.assert_no_reader(7) + + tr.resume_reading() + self.assertTrue(tr.is_reading()) + + tr.close() + self.assertFalse(tr.is_reading()) + + + def pause_writing_transport(self, high): + tr = self.socket_transport() + tr.set_write_buffer_limits(high=high) + + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + self.assertFalse(self.protocol.resume_writing.called) + return tr + + def test_pause_resume_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk, must pause writing + fut = self.loop.create_future() + self.loop._proactor.send.return_value = fut + tr.write(b'large data') + self.loop._run_once() + self.assertTrue(self.protocol.pause_writing.called) + + # flush the buffer + fut.set_result(None) + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertTrue(self.protocol.resume_writing.called) + + def test_pause_writing_2write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (3 <= 4) + fut1 = self.loop.create_future() + self.loop._proactor.send.return_value = fut1 + tr.write(b'123') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + 
self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_pause_writing_3write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (1 <= 4) + fut = self.loop.create_future() + self.loop._proactor.send.return_value = fut + tr.write(b'1') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 1) + self.assertFalse(self.protocol.pause_writing.called) + + # second short write, the buffer is not full (3 <= 4) + tr.write(b'23') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_dont_pause_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk which completes immediately, + # it should not pause writing + fut = self.loop.create_future() + fut.set_result(None) + self.loop._proactor.send.return_value = fut + tr.write(b'very large data') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + + +class ProactorDatagramTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.proactor = mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) + self.sock = mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + + def datagram_transport(self, address=None): + self.sock.getpeername.side_effect = None if address else OSError + transport = _ProactorDatagramTransport(self.loop, self.sock, + self.protocol, + address=address) + self.addCleanup(close_transport, transport) + return transport + + def test_sendto(self): 
+ data = b'data' + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.proactor.sendto.called) + self.proactor.sendto.assert_called_with( + self.sock, data, addr=('0.0.0.0', 1234)) + self.assertFalse(transport._buffer) + self.assertEqual(0, transport._buffer_size) + + def test_sendto_bytearray(self): + data = bytearray(b'data') + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.proactor.sendto.called) + self.proactor.sendto.assert_called_with( + self.sock, b'data', addr=('0.0.0.0', 1234)) + + def test_sendto_memoryview(self): + data = memoryview(b'data') + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.proactor.sendto.called) + self.proactor.sendto.assert_called_with( + self.sock, b'data', addr=('0.0.0.0', 1234)) + + # TODO: RUSTPYTHON + # AssertionError: False is not true + @unittest.expectedFailure + def test_sendto_no_data(self): + transport = self.datagram_transport() + transport.sendto(b'', ('0.0.0.0', 1234)) + self.assertTrue(self.proactor.sendto.called) + self.proactor.sendto.assert_called_with( + self.sock, b'', addr=('0.0.0.0', 1234)) + + def test_sendto_buffer(self): + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport._write_fut = object() + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.proactor.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_buffer_bytearray(self): + data2 = bytearray(b'data2') + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport._write_fut = object() + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.proactor.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], 
+ list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + def test_sendto_buffer_memoryview(self): + data2 = memoryview(b'data2') + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport._write_fut = object() + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.proactor.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + # TODO: RUSTPYTHON + # AssertionError: Lists differ: [(b'data1', ('0.0.0.0', 12345)), (b'', ('0.0.0.0', 12345))] != [(b'data1', ('0.0.0.0', 12345))] + @unittest.expectedFailure + def test_sendto_buffer_nodata(self): + data2 = b'' + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport._write_fut = object() + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.proactor.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + @mock.patch('asyncio.proactor_events.logger') + def test_sendto_exception(self, m_log): + data = b'data' + err = self.proactor.sendto.side_effect = RuntimeError() + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + transport._conn_lost = 1 + + transport._address = ('123',) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + m_log.warning.assert_called_with('socket.sendto() raised exception.') + + def test_sendto_error_received(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + 
transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertEqual(transport._conn_lost, 0) + self.assertFalse(transport._fatal_error.called) + + def test_sendto_error_received_connected(self): + data = b'data' + + self.proactor.send.side_effect = ConnectionRefusedError + + transport = self.datagram_transport(address=('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport.sendto(data) + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + def test_sendto_str(self): + transport = self.datagram_transport() + self.assertRaises(TypeError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = self.datagram_transport(address=('0.0.0.0', 1)) + self.assertRaises( + ValueError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = self.datagram_transport(address=(1,)) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.sendto(b'data', (1,)) + self.assertEqual(transport._conn_lost, 2) + + def test__loop_writing_closing(self): + transport = self.datagram_transport() + transport._closing = True + transport._loop_writing() + self.assertIsNone(transport._write_fut) + test_utils.run_briefly(self.loop) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test__loop_writing_exception(self): + err = self.proactor.sendto.side_effect = RuntimeError() + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._loop_writing() + + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + + def test__loop_writing_error_received(self): + self.proactor.sendto.side_effect = ConnectionRefusedError + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', 
())) + transport._loop_writing() + + self.assertFalse(transport._fatal_error.called) + + def test__loop_writing_error_received_connection(self): + self.proactor.send.side_effect = ConnectionRefusedError + + transport = self.datagram_transport(address=('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._loop_writing() + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + @mock.patch('asyncio.base_events.logger.error') + def test_fatal_error_connected(self, m_exc): + transport = self.datagram_transport(address=('0.0.0.0', 1)) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.assertFalse(self.protocol.error_received.called) + m_exc.assert_not_called() + + +class BaseProactorEventLoopTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + + self.sock = test_utils.mock_nonblocking_socket() + self.proactor = mock.Mock() + + self.ssock, self.csock = mock.Mock(), mock.Mock() + + with mock.patch('asyncio.proactor_events.socket.socketpair', + return_value=(self.ssock, self.csock)): + with mock.patch('signal.set_wakeup_fd'): + self.loop = BaseProactorEventLoop(self.proactor) + self.set_event_loop(self.loop) + + @mock.patch('asyncio.proactor_events.socket.socketpair') + def test_ctor(self, socketpair): + ssock, csock = socketpair.return_value = ( + mock.Mock(), mock.Mock()) + with mock.patch('signal.set_wakeup_fd'): + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + loop.close() + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + # Don't call close(): _close_self_pipe() cannot be called twice + 
self.loop._closed = True + + def test_close(self): + self.loop._close_self_pipe = mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) + self.assertIsInstance(tr, _ProactorSocketTransport) + close_transport(tr) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = mock.Mock() + self.loop._self_reading_future = fut + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.call_exception_handler = mock.Mock() + self.proactor.recv.side_effect = OSError() + self.loop._loop_self_reading() + self.assertTrue(self.loop.call_exception_handler.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'\0') + + def test_process_events(self): + self.loop._process_events([]) + + @mock.patch('asyncio.base_events.logger') + def test_create_server(self, m_log): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = mock.Mock() + fut.result.return_value = (mock.Mock(), mock.Mock()) + + make_tr = 
self.loop._make_socket_transport = mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.error.called) + + def test_create_server_cancel(self): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = self.loop.create_future() + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock1 = mock.Mock() + future1 = mock.Mock() + sock2 = mock.Mock() + future2 = mock.Mock() + self.loop._accept_futures = { + sock1.fileno(): future1, + sock2.fileno(): future2 + } + + self.loop._stop_serving(sock1) + self.assertTrue(sock1.close.called) + self.assertTrue(future1.cancel.called) + self.proactor._stop_serving.assert_called_with(sock1) + self.assertFalse(sock2.close.called) + self.assertFalse(future2.cancel.called) + + def datagram_transport(self): + self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) + return self.loop._make_datagram_transport(self.sock, self.protocol) + + def test_make_datagram_transport(self): + tr = self.datagram_transport() + self.assertIsInstance(tr, _ProactorDatagramTransport) + self.assertIsInstance(tr, asyncio.DatagramTransport) + close_transport(tr) + + def test_datagram_loop_writing(self): + tr = self.datagram_transport() + tr._buffer.appendleft((b'data', ('127.0.0.1', 12068))) + tr._loop_writing() + self.loop._proactor.sendto.assert_called_with(self.sock, b'data', addr=('127.0.0.1', 12068)) + self.loop._proactor.sendto.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + close_transport(tr) + + def test_datagram_loop_reading(self): + tr = self.datagram_transport() + tr._loop_reading() + self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024) + 
self.assertFalse(self.protocol.datagram_received.called) + self.assertFalse(self.protocol.error_received.called) + close_transport(tr) + + def test_datagram_loop_reading_data(self): + res = self.loop.create_future() + res.set_result((b'data', ('127.0.0.1', 12068))) + + tr = self.datagram_transport() + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024) + self.protocol.datagram_received.assert_called_with(b'data', ('127.0.0.1', 12068)) + close_transport(tr) + + @unittest.skipIf(sys.flags.optimize, "Assertions are disabled in optimized mode") + def test_datagram_loop_reading_no_data(self): + res = self.loop.create_future() + res.set_result((b'', ('127.0.0.1', 12068))) + + tr = self.datagram_transport() + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertTrue(self.loop._proactor.recvfrom.called) + self.assertFalse(self.protocol.error_received.called) + self.assertFalse(tr.close.called) + close_transport(tr) + + def test_datagram_loop_reading_aborted(self): + err = self.loop._proactor.recvfrom.side_effect = ConnectionAbortedError() + + tr = self.datagram_transport() + tr._fatal_error = mock.Mock() + tr._protocol.error_received = mock.Mock() + tr._loop_reading() + tr._protocol.error_received.assert_called_with(err) + close_transport(tr) + + def test_datagram_loop_writing_aborted(self): + err = self.loop._proactor.sendto.side_effect = ConnectionAbortedError() + + tr = self.datagram_transport() + tr._fatal_error = mock.Mock() + tr._protocol.error_received = mock.Mock() + tr._buffer.appendleft((b'Hello', ('127.0.0.1', 12068))) + tr._loop_writing() + tr._protocol.error_received.assert_called_with(err) + close_transport(tr) + + +@unittest.skipIf(sys.platform != 'win32', + 'Proactor is supported on Windows only') +class ProactorEventLoopUnixSockSendfileTests(test_utils.TestCase): + DATA = b"12345abcde" * 16 * 1024 # 160 
KiB + + class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + + def connection_made(self, transport): + self.started = True + self.transport = transport + + def data_received(self, data): + self.data.extend(data) + + def connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + + async def wait_closed(self): + await self.fut + + @classmethod + def setUpClass(cls): + with open(os_helper.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + @classmethod + def tearDownClass(cls): + os_helper.unlink(os_helper.TESTFN) + super().tearDownClass() + + def setUp(self): + self.loop = asyncio.ProactorEventLoop() + self.set_event_loop(self.loop) + self.addCleanup(self.loop.close) + self.file = open(os_helper.TESTFN, 'rb') + self.addCleanup(self.file.close) + super().setUp() + + def make_socket(self, cleanup=True, blocking=False): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(blocking) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + if cleanup: + self.addCleanup(sock.close) + return sock + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + def prepare(self): + sock = self.make_socket() + proto = self.MyProto(self.loop) + port = socket_helper.find_unused_port() + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind(('127.0.0.1', port)) + server = self.run_loop(self.loop.create_server( + lambda: proto, sock=srv_sock)) + self.run_loop(self.loop.sock_connect(sock, srv_sock.getsockname())) + + def cleanup(): + if proto.transport is not None: + # can be None if the task was cancelled before + # connection_made callback + proto.transport.close() + self.run_loop(proto.wait_closed()) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return 
sock, proto + + def test_sock_sendfile_not_a_file(self): + sock, proto = self.prepare() + f = object() + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_iobuffer(self): + sock, proto = self.prepare() + f = io.BytesIO() + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_not_regular_file(self): + sock, proto = self.prepare() + f = mock.Mock() + f.fileno.return_value = -1 + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_blocking_socket(self): + self.loop.set_debug(True) + sock = self.make_socket(blocking=True) + with self.assertRaisesRegex(ValueError, "must be non-blocking"): + self.run_loop(self.loop.sock_sendfile(sock, self.file)) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_protocols.py b/Lib/test/test_asyncio/test_protocols.py new file mode 100644 index 00000000000..0f232631867 --- /dev/null +++ b/Lib/test/test_asyncio/test_protocols.py @@ -0,0 +1,67 @@ +import unittest +from unittest import mock + +import asyncio + + +def tearDownModule(): + # not needed for the test file but added for uniformness with all other + # asyncio test files for the sake of unified cleanup + asyncio.set_event_loop_policy(None) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_base_protocol(self): + f = mock.Mock() + p = asyncio.BaseProtocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.pause_writing()) + self.assertIsNone(p.resume_writing()) + 
self.assertFalse(hasattr(p, '__dict__')) + + def test_protocol(self): + f = mock.Mock() + p = asyncio.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + self.assertIsNone(p.pause_writing()) + self.assertIsNone(p.resume_writing()) + self.assertFalse(hasattr(p, '__dict__')) + + def test_buffered_protocol(self): + f = mock.Mock() + p = asyncio.BufferedProtocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.get_buffer(100)) + self.assertIsNone(p.buffer_updated(150)) + self.assertIsNone(p.pause_writing()) + self.assertIsNone(p.resume_writing()) + self.assertFalse(hasattr(p, '__dict__')) + + def test_datagram_protocol(self): + f = mock.Mock() + dp = asyncio.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.error_received(f)) + self.assertIsNone(dp.datagram_received(f, f)) + self.assertFalse(hasattr(dp, '__dict__')) + + def test_subprocess_protocol(self): + f = mock.Mock() + sp = asyncio.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + self.assertFalse(hasattr(sp, '__dict__')) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py new file mode 100644 index 00000000000..a42432a85e2 --- /dev/null +++ b/Lib/test/test_asyncio/test_queues.py @@ -0,0 +1,729 @@ +"""Tests for queues.py""" + +import asyncio +import unittest +from types import GenericAlias + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class QueueBasicTests(unittest.IsolatedAsyncioTestCase): + + async def _test_repr_or_str(self, fn, expect_id): + """Test 
Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = asyncio.Queue() + self.assertTrue(fn(q).startswith(' 0 + + self.assertEqual(q.qsize(), 0) + + # Ensure join() task successfully finishes + await q.join() + + # Ensure get() task is finished, and raised ShutDown + await asyncio.sleep(0) + self.assertTrue(get_task.done()) + with self.assertRaisesShutdown(): + await get_task + + # Ensure put() and get() raise ShutDown + with self.assertRaisesShutdown(): + await q.put("data") + with self.assertRaisesShutdown(): + q.put_nowait("data") + + with self.assertRaisesShutdown(): + await q.get() + with self.assertRaisesShutdown(): + q.get_nowait() + + async def test_shutdown_nonempty(self): + # Test shutting down a non-empty queue + + # Setup full queue with 1 item, and join() and put() tasks + q = self.q_class(maxsize=1) + loop = asyncio.get_running_loop() + + q.put_nowait("data") + join_task = loop.create_task(q.join()) + put_task = loop.create_task(q.put("data2")) + + # Ensure put() task is not finished + await asyncio.sleep(0) + self.assertFalse(put_task.done()) + + # Perform shut-down + q.shutdown(immediate=False) # unfinished tasks: 1 -> 1 + + self.assertEqual(q.qsize(), 1) + + # Ensure put() task is finished, and raised ShutDown + await asyncio.sleep(0) + self.assertTrue(put_task.done()) + with self.assertRaisesShutdown(): + await put_task + + # Ensure get() succeeds on enqueued item + self.assertEqual(await q.get(), "data") + + # Ensure join() task is not finished + await asyncio.sleep(0) + self.assertFalse(join_task.done()) + + # Ensure put() and get() raise ShutDown + with self.assertRaisesShutdown(): + await q.put("data") + with self.assertRaisesShutdown(): + q.put_nowait("data") + + with self.assertRaisesShutdown(): + await q.get() + with self.assertRaisesShutdown(): + q.get_nowait() + + # Ensure there is 1 unfinished task, and join() task succeeds + q.task_done() + + await 
asyncio.sleep(0) + self.assertTrue(join_task.done()) + await join_task + + with self.assertRaises( + ValueError, msg="Didn't appear to mark all tasks done" + ): + q.task_done() + + async def test_shutdown_immediate(self): + # Test immediately shutting down a queue + + # Setup queue with 1 item, and a join() task + q = self.q_class() + loop = asyncio.get_running_loop() + q.put_nowait("data") + join_task = loop.create_task(q.join()) + + # Perform shut-down + q.shutdown(immediate=True) # unfinished tasks: 1 -> 0 + + self.assertEqual(q.qsize(), 0) + + # Ensure join() task has successfully finished + await asyncio.sleep(0) + self.assertTrue(join_task.done()) + await join_task + + # Ensure put() and get() raise ShutDown + with self.assertRaisesShutdown(): + await q.put("data") + with self.assertRaisesShutdown(): + q.put_nowait("data") + + with self.assertRaisesShutdown(): + await q.get() + with self.assertRaisesShutdown(): + q.get_nowait() + + # Ensure there are no unfinished tasks + with self.assertRaises( + ValueError, msg="Didn't appear to mark all tasks done" + ): + q.task_done() + + async def test_shutdown_immediate_with_unfinished(self): + # Test immediately shutting down a queue with unfinished tasks + + # Setup queue with 2 items (1 retrieved), and a join() task + q = self.q_class() + loop = asyncio.get_running_loop() + q.put_nowait("data") + q.put_nowait("data") + join_task = loop.create_task(q.join()) + self.assertEqual(await q.get(), "data") + + # Perform shut-down + q.shutdown(immediate=True) # unfinished tasks: 2 -> 1 + + self.assertEqual(q.qsize(), 0) + + # Ensure join() task is not finished + await asyncio.sleep(0) + self.assertFalse(join_task.done()) + + # Ensure put() and get() raise ShutDown + with self.assertRaisesShutdown(): + await q.put("data") + with self.assertRaisesShutdown(): + q.put_nowait("data") + + with self.assertRaisesShutdown(): + await q.get() + with self.assertRaisesShutdown(): + q.get_nowait() + + # Ensure there is 1 unfinished task + 
q.task_done() + with self.assertRaises( + ValueError, msg="Didn't appear to mark all tasks done" + ): + q.task_done() + + # Ensure join() task has successfully finished + await asyncio.sleep(0) + self.assertTrue(join_task.done()) + await join_task + +@unittest.skip('TODO: RUSTPYTHON') +# AttributeError: 'Queue' object has no attribute 'shutdown' +class QueueShutdownTests( + _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase +): + q_class = asyncio.Queue + + +@unittest.skip('TODO: RUSTPYTHON') +# AttributeError: 'LifoQueue' object has no attribute 'shutdown' +class LifoQueueShutdownTests( + _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase +): + q_class = asyncio.LifoQueue + +@unittest.skip('TODO: RUSTPYTHON') +# AttributeError: 'PriorityQueue' object has no attribute 'shutdown' +class PriorityQueueShutdownTests( + _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase +): + q_class = asyncio.PriorityQueue + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_runners.py b/Lib/test/test_asyncio/test_runners.py new file mode 100644 index 00000000000..1c4eb2f8ccb --- /dev/null +++ b/Lib/test/test_asyncio/test_runners.py @@ -0,0 +1,526 @@ +import _thread +import asyncio +import contextvars +import re +import signal +import sys +import threading +import unittest +from test.test_asyncio import utils as test_utils +from unittest import mock +from unittest.mock import patch + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +def interrupt_self(): + _thread.interrupt_main() + + +class TestPolicy(asyncio.AbstractEventLoopPolicy): + + def __init__(self, loop_factory): + self.loop_factory = loop_factory + self.loop = None + + def get_event_loop(self): + # shouldn't ever be called by asyncio.run() + raise RuntimeError + + def new_event_loop(self): + return self.loop_factory() + + def set_event_loop(self, loop): + if loop is not None: + # we want to check if the loop is closed + # in BaseTest.tearDown + 
self.loop = loop + + +class BaseTest(unittest.TestCase): + + def new_loop(self): + loop = asyncio.BaseEventLoop() + loop._process_events = mock.Mock() + # Mock waking event loop from select + loop._write_to_self = mock.Mock() + loop._write_to_self.return_value = None + loop._selector = mock.Mock() + loop._selector.select.return_value = () + loop.shutdown_ag_run = False + + async def shutdown_asyncgens(): + loop.shutdown_ag_run = True + loop.shutdown_asyncgens = shutdown_asyncgens + + return loop + + def setUp(self): + super().setUp() + + policy = TestPolicy(self.new_loop) + asyncio.set_event_loop_policy(policy) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + if policy.loop is not None: + self.assertTrue(policy.loop.is_closed()) + self.assertTrue(policy.loop.shutdown_ag_run) + + asyncio.set_event_loop_policy(None) + super().tearDown() + + +class RunTests(BaseTest): + + def test_asyncio_run_return(self): + async def main(): + await asyncio.sleep(0) + return 42 + + self.assertEqual(asyncio.run(main()), 42) + + def test_asyncio_run_raises(self): + async def main(): + await asyncio.sleep(0) + raise ValueError('spam') + + with self.assertRaisesRegex(ValueError, 'spam'): + asyncio.run(main()) + + def test_asyncio_run_only_coro(self): + for o in {1, lambda: None}: + with self.subTest(obj=o), \ + self.assertRaisesRegex(ValueError, + 'a coroutine was expected'): + asyncio.run(o) + + def test_asyncio_run_debug(self): + async def main(expected): + loop = asyncio.get_event_loop() + self.assertIs(loop.get_debug(), expected) + + asyncio.run(main(False), debug=False) + asyncio.run(main(True), debug=True) + with mock.patch('asyncio.coroutines._is_debug_mode', lambda: True): + asyncio.run(main(True)) + asyncio.run(main(False), debug=False) + with mock.patch('asyncio.coroutines._is_debug_mode', lambda: False): + asyncio.run(main(True), debug=True) + asyncio.run(main(False)) + + def test_asyncio_run_from_running_loop(self): + async def main(): + coro = main() + 
try: + asyncio.run(coro) + finally: + coro.close() # Suppress ResourceWarning + + with self.assertRaisesRegex(RuntimeError, + 'cannot be called from a running'): + asyncio.run(main()) + + def test_asyncio_run_cancels_hanging_tasks(self): + lo_task = None + + async def leftover(): + await asyncio.sleep(0.1) + + async def main(): + nonlocal lo_task + lo_task = asyncio.create_task(leftover()) + return 123 + + self.assertEqual(asyncio.run(main()), 123) + self.assertTrue(lo_task.done()) + + def test_asyncio_run_reports_hanging_tasks_errors(self): + lo_task = None + call_exc_handler_mock = mock.Mock() + + async def leftover(): + try: + await asyncio.sleep(0.1) + except asyncio.CancelledError: + 1 / 0 + + async def main(): + loop = asyncio.get_running_loop() + loop.call_exception_handler = call_exc_handler_mock + + nonlocal lo_task + lo_task = asyncio.create_task(leftover()) + return 123 + + self.assertEqual(asyncio.run(main()), 123) + self.assertTrue(lo_task.done()) + + call_exc_handler_mock.assert_called_with({ + 'message': test_utils.MockPattern(r'asyncio.run.*shutdown'), + 'task': lo_task, + 'exception': test_utils.MockInstanceOf(ZeroDivisionError) + }) + + # TODO: RUSTPYTHON + # AssertionError: is not None + @unittest.expectedFailure + def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self): + spinner = None + lazyboy = None + + class FancyExit(Exception): + pass + + async def fidget(): + while True: + yield 1 + await asyncio.sleep(1) + + async def spin(): + nonlocal spinner + spinner = fidget() + try: + async for the_meaning_of_life in spinner: # NoQA + pass + except asyncio.CancelledError: + 1 / 0 + + async def main(): + loop = asyncio.get_running_loop() + loop.call_exception_handler = mock.Mock() + + nonlocal lazyboy + lazyboy = asyncio.create_task(spin()) + raise FancyExit + + with self.assertRaises(FancyExit): + asyncio.run(main()) + + self.assertTrue(lazyboy.done()) + + self.assertIsNone(spinner.ag_frame) + self.assertFalse(spinner.ag_running) + + def 
test_asyncio_run_set_event_loop(self): + #See https://github.com/python/cpython/issues/93896 + + async def main(): + await asyncio.sleep(0) + return 42 + + policy = asyncio.get_event_loop_policy() + policy.set_event_loop = mock.Mock() + asyncio.run(main()) + self.assertTrue(policy.set_event_loop.called) + + def test_asyncio_run_without_uncancel(self): + # See https://github.com/python/cpython/issues/95097 + class Task: + def __init__(self, loop, coro, **kwargs): + self._task = asyncio.Task(coro, loop=loop, **kwargs) + + def cancel(self, *args, **kwargs): + return self._task.cancel(*args, **kwargs) + + def add_done_callback(self, *args, **kwargs): + return self._task.add_done_callback(*args, **kwargs) + + def remove_done_callback(self, *args, **kwargs): + return self._task.remove_done_callback(*args, **kwargs) + + @property + def _asyncio_future_blocking(self): + return self._task._asyncio_future_blocking + + def result(self, *args, **kwargs): + return self._task.result(*args, **kwargs) + + def done(self, *args, **kwargs): + return self._task.done(*args, **kwargs) + + def cancelled(self, *args, **kwargs): + return self._task.cancelled(*args, **kwargs) + + def exception(self, *args, **kwargs): + return self._task.exception(*args, **kwargs) + + def get_loop(self, *args, **kwargs): + return self._task.get_loop(*args, **kwargs) + + def set_name(self, *args, **kwargs): + return self._task.set_name(*args, **kwargs) + + async def main(): + interrupt_self() + await asyncio.Event().wait() + + def new_event_loop(): + loop = self.new_loop() + loop.set_task_factory(Task) + return loop + + asyncio.set_event_loop_policy(TestPolicy(new_event_loop)) + with self.assertRaises(asyncio.CancelledError): + asyncio.run(main()) + + def test_asyncio_run_loop_factory(self): + factory = mock.Mock() + loop = factory.return_value = self.new_loop() + + async def main(): + self.assertEqual(asyncio.get_running_loop(), loop) + + asyncio.run(main(), loop_factory=factory) + 
factory.assert_called_once_with() + + @unittest.skip('TODO: RUSTPYTHON') + # module 'asyncio' has no attribute 'EventLoop' + def test_loop_factory_default_event_loop(self): + async def main(): + if sys.platform == "win32": + self.assertIsInstance(asyncio.get_running_loop(), asyncio.ProactorEventLoop) + else: + self.assertIsInstance(asyncio.get_running_loop(), asyncio.SelectorEventLoop) + + + asyncio.run(main(), loop_factory=asyncio.EventLoop) + + +class RunnerTests(BaseTest): + + def test_non_debug(self): + with asyncio.Runner(debug=False) as runner: + self.assertFalse(runner.get_loop().get_debug()) + + def test_debug(self): + with asyncio.Runner(debug=True) as runner: + self.assertTrue(runner.get_loop().get_debug()) + + def test_custom_factory(self): + loop = mock.Mock() + with asyncio.Runner(loop_factory=lambda: loop) as runner: + self.assertIs(runner.get_loop(), loop) + + def test_run(self): + async def f(): + await asyncio.sleep(0) + return 'done' + + with asyncio.Runner() as runner: + self.assertEqual('done', runner.run(f())) + loop = runner.get_loop() + + with self.assertRaisesRegex( + RuntimeError, + "Runner is closed" + ): + runner.get_loop() + + self.assertTrue(loop.is_closed()) + + def test_run_non_coro(self): + with asyncio.Runner() as runner: + with self.assertRaisesRegex( + ValueError, + "a coroutine was expected" + ): + runner.run(123) + + def test_run_future(self): + with asyncio.Runner() as runner: + with self.assertRaisesRegex( + ValueError, + "a coroutine was expected" + ): + fut = runner.get_loop().create_future() + runner.run(fut) + + def test_explicit_close(self): + runner = asyncio.Runner() + loop = runner.get_loop() + runner.close() + with self.assertRaisesRegex( + RuntimeError, + "Runner is closed" + ): + runner.get_loop() + + self.assertTrue(loop.is_closed()) + + def test_double_close(self): + runner = asyncio.Runner() + loop = runner.get_loop() + + runner.close() + self.assertTrue(loop.is_closed()) + + # the second call is no-op + 
runner.close() + self.assertTrue(loop.is_closed()) + + def test_second_with_block_raises(self): + ret = [] + + async def f(arg): + ret.append(arg) + + runner = asyncio.Runner() + with runner: + runner.run(f(1)) + + with self.assertRaisesRegex( + RuntimeError, + "Runner is closed" + ): + with runner: + runner.run(f(2)) + + self.assertEqual([1], ret) + + def test_run_keeps_context(self): + cvar = contextvars.ContextVar("cvar", default=-1) + + async def f(val): + old = cvar.get() + await asyncio.sleep(0) + cvar.set(val) + return old + + async def get_context(): + return contextvars.copy_context() + + with asyncio.Runner() as runner: + self.assertEqual(-1, runner.run(f(1))) + self.assertEqual(1, runner.run(f(2))) + + self.assertEqual(2, runner.run(get_context()).get(cvar)) + + # TODO: RUSTPYTHON + # AssertionError: RuntimeWarning not triggered + @unittest.expectedFailure + def test_recursive_run(self): + async def g(): + pass + + async def f(): + runner.run(g()) + + with asyncio.Runner() as runner: + with self.assertWarnsRegex( + RuntimeWarning, + "coroutine .+ was never awaited", + ): + with self.assertRaisesRegex( + RuntimeError, + re.escape( + "Runner.run() cannot be called from a running event loop" + ), + ): + runner.run(f()) + + def test_interrupt_call_soon(self): + # The only case when task is not suspended by waiting a future + # or another task + assert threading.current_thread() is threading.main_thread() + + async def coro(): + with self.assertRaises(asyncio.CancelledError): + while True: + await asyncio.sleep(0) + raise asyncio.CancelledError() + + with asyncio.Runner() as runner: + runner.get_loop().call_later(0.1, interrupt_self) + with self.assertRaises(KeyboardInterrupt): + runner.run(coro()) + + def test_interrupt_wait(self): + # interrupting when waiting a future cancels both future and main task + assert threading.current_thread() is threading.main_thread() + + async def coro(fut): + with self.assertRaises(asyncio.CancelledError): + await fut + raise 
asyncio.CancelledError() + + with asyncio.Runner() as runner: + fut = runner.get_loop().create_future() + runner.get_loop().call_later(0.1, interrupt_self) + + with self.assertRaises(KeyboardInterrupt): + runner.run(coro(fut)) + + self.assertTrue(fut.cancelled()) + + def test_interrupt_cancelled_task(self): + # interrupting cancelled main task doesn't raise KeyboardInterrupt + assert threading.current_thread() is threading.main_thread() + + async def subtask(task): + await asyncio.sleep(0) + task.cancel() + interrupt_self() + + async def coro(): + asyncio.create_task(subtask(asyncio.current_task())) + await asyncio.sleep(10) + + with asyncio.Runner() as runner: + with self.assertRaises(asyncio.CancelledError): + runner.run(coro()) + + def test_signal_install_not_supported_ok(self): + # signal.signal() can throw if the "main thread" doesn't have signals enabled + assert threading.current_thread() is threading.main_thread() + + async def coro(): + pass + + with asyncio.Runner() as runner: + with patch.object( + signal, + "signal", + side_effect=ValueError( + "signal only works in main thread of the main interpreter" + ) + ): + runner.run(coro()) + + def test_set_event_loop_called_once(self): + # See https://github.com/python/cpython/issues/95736 + async def coro(): + pass + + policy = asyncio.get_event_loop_policy() + policy.set_event_loop = mock.Mock() + runner = asyncio.Runner() + runner.run(coro()) + runner.run(coro()) + + self.assertEqual(1, policy.set_event_loop.call_count) + runner.close() + + def test_no_repr_is_call_on_the_task_result(self): + # See https://github.com/python/cpython/issues/112559. 
+ class MyResult: + def __init__(self): + self.repr_count = 0 + def __repr__(self): + self.repr_count += 1 + return super().__repr__() + + async def coro(): + return MyResult() + + + with asyncio.Runner() as runner: + result = runner.run(coro()) + + self.assertEqual(0, result.repr_count) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py new file mode 100644 index 00000000000..65277aed7e0 --- /dev/null +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -0,0 +1,1645 @@ +"""Tests for selector_events.py""" + +import collections +import selectors +import socket +import sys +import unittest +from asyncio import selector_events +from unittest import mock + +try: + import ssl +except ImportError: + ssl = None + +import asyncio +from asyncio.selector_events import (BaseSelectorEventLoop, + _SelectorDatagramTransport, + _SelectorSocketTransport, + _SelectorTransport) +from test.test_asyncio import utils as test_utils + +MOCK_ANY = mock.ANY + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + self._ssock = mock.Mock() + self._csock = mock.Mock() + self._internal_fds += 1 + + def _close_self_pipe(self): + pass + + +def list_to_buffer(l=()): + buffer = collections.deque() + buffer.extend((memoryview(i) for i in l)) + return buffer + + + +def close_transport(transport): + # Don't call transport.close() because the event loop and the selector + # are mocked + if transport._sock is None: + return + transport._sock.close() + transport._sock = None + + +class BaseSelectorEventLoopTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.selector = mock.Mock() + self.selector.select.return_value = [] + self.loop = TestBaseSelectorEventLoop(self.selector) + self.set_event_loop(self.loop) + + def test_make_socket_transport(self): + m = mock.Mock() + 
self.loop.add_reader = mock.Mock() + self.loop._ensure_fd_no_transport = mock.Mock() + transport = self.loop._make_socket_transport(m, asyncio.Protocol()) + self.assertIsInstance(transport, _SelectorSocketTransport) + self.assertEqual(self.loop._ensure_fd_no_transport.call_count, 1) + + # Calling repr() must not fail when the event loop is closed + self.loop.close() + repr(transport) + + close_transport(transport) + + @mock.patch('asyncio.selector_events.ssl', None) + @mock.patch('asyncio.sslproto.ssl', None) + def test_make_ssl_transport_without_ssl_error(self): + m = mock.Mock() + self.loop.add_reader = mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.remove_writer = mock.Mock() + self.loop._ensure_fd_no_transport = mock.Mock() + with self.assertRaises(RuntimeError): + self.loop._make_ssl_transport(m, m, m, m) + self.assertEqual(self.loop._ensure_fd_no_transport.call_count, 1) + + def test_close(self): + class EventLoop(BaseSelectorEventLoop): + def _make_self_pipe(self): + self._ssock = mock.Mock() + self._csock = mock.Mock() + self._internal_fds += 1 + + self.loop = EventLoop(self.selector) + self.set_event_loop(self.loop) + + ssock = self.loop._ssock + ssock.fileno.return_value = 7 + csock = self.loop._csock + csock.fileno.return_value = 1 + remove_reader = self.loop._remove_reader = mock.Mock() + + self.loop._selector.close() + self.loop._selector = selector = mock.Mock() + self.assertFalse(self.loop.is_closed()) + + self.loop.close() + self.assertTrue(self.loop.is_closed()) + self.assertIsNone(self.loop._selector) + self.assertIsNone(self.loop._csock) + self.assertIsNone(self.loop._ssock) + selector.close.assert_called_with() + ssock.close.assert_called_with() + csock.close.assert_called_with() + remove_reader.assert_called_with(7) + + # it should be possible to call close() more than once + self.loop.close() + self.loop.close() + + # operation blocked when the loop is closed + f = self.loop.create_future() 
+ self.assertRaises(RuntimeError, self.loop.run_forever) + self.assertRaises(RuntimeError, self.loop.run_until_complete, f) + fd = 0 + def callback(): + pass + self.assertRaises(RuntimeError, self.loop.add_reader, fd, callback) + self.assertRaises(RuntimeError, self.loop.add_writer, fd, callback) + + def test_close_no_selector(self): + self.loop.remove_reader = mock.Mock() + self.loop._selector.close() + self.loop._selector = None + self.loop.close() + self.assertIsNone(self.loop._selector) + + def test_read_from_self_tryagain(self): + self.loop._ssock.recv.side_effect = BlockingIOError + self.assertIsNone(self.loop._read_from_self()) + + def test_read_from_self_exception(self): + self.loop._ssock.recv.side_effect = OSError + self.assertRaises(OSError, self.loop._read_from_self) + + def test_write_to_self_tryagain(self): + self.loop._csock.send.side_effect = BlockingIOError + with test_utils.disable_logger(): + self.assertIsNone(self.loop._write_to_self()) + + def test_write_to_self_exception(self): + # _write_to_self() swallows OSError + self.loop._csock.send.side_effect = RuntimeError() + self.assertRaises(RuntimeError, self.loop._write_to_self) + + @mock.patch('socket.getaddrinfo') + def test_sock_connect_resolve_using_socket_params(self, m_gai): + addr = ('need-resolution.com', 8080) + for sock_type in [socket.SOCK_STREAM, socket.SOCK_DGRAM]: + with self.subTest(sock_type): + sock = test_utils.mock_nonblocking_socket(type=sock_type) + + m_gai.side_effect = \ + lambda *args: [(None, None, None, None, ('127.0.0.1', 0))] + + con = self.loop.create_task(self.loop.sock_connect(sock, addr)) + self.loop.run_until_complete(con) + m_gai.assert_called_with( + addr[0], addr[1], sock.family, sock.type, sock.proto, 0) + + self.loop.run_until_complete(con) + sock.connect.assert_called_with(('127.0.0.1', 0)) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_add_reader(self): + self.loop._selector.get_map.return_value = {} 
+ cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertIsNone(w) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_add_reader_existing(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_map.return_value = {1: selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (reader, writer))} + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_add_reader_existing_writer(self): + writer = mock.Mock() + self.loop._selector.get_map.return_value = {1: selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, writer))} + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_remove_reader(self): + self.loop._selector.get_map.return_value = {1: selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (None, None))} + self.assertFalse(self.loop.remove_reader(1)) + + 
self.assertTrue(self.loop._selector.unregister.called) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_remove_reader_read_write(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_map.return_value = {1: selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer))} + self.assertTrue( + self.loop.remove_reader(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.loop._selector.modify.call_args[0]) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_remove_reader_unknown(self): + self.loop._selector.get_map.return_value = {} + self.assertFalse( + self.loop.remove_reader(1)) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_add_writer(self): + self.loop._selector.get_map.return_value = {} + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE, mask) + self.assertIsNone(r) + self.assertEqual(cb, w._callback) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_add_writer_existing(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_map.return_value = {1: selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, writer))} + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(reader, r) + 
self.assertEqual(cb, w._callback) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_remove_writer(self): + self.loop._selector.get_map.return_value = {1: selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, None))} + self.assertFalse(self.loop.remove_writer(1)) + + self.assertTrue(self.loop._selector.unregister.called) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_remove_writer_read_write(self): + reader = mock.Mock() + writer = mock.Mock() + self.loop._selector.get_map.return_value = {1: selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer))} + self.assertTrue( + self.loop.remove_writer(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.loop._selector.modify.call_args[0]) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: cannot unpack non-iterable Mock object + def test_remove_writer_unknown(self): + self.loop._selector.get_map.return_value = {} + self.assertFalse( + self.loop.remove_writer(1)) + + def test_process_events_read(self): + reader = mock.Mock() + reader._cancelled = False + + self.loop._add_callback = mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.assertTrue(self.loop._add_callback.called) + self.loop._add_callback.assert_called_with(reader) + + def test_process_events_read_cancelled(self): + reader = mock.Mock() + reader.cancelled = True + + self.loop._remove_reader = mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.loop._remove_reader.assert_called_with(1) + + def test_process_events_write(self): + writer = mock.Mock() + writer._cancelled = False + + self.loop._add_callback = mock.Mock() + 
self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop._add_callback.assert_called_with(writer) + + def test_process_events_write_cancelled(self): + writer = mock.Mock() + writer.cancelled = True + self.loop._remove_writer = mock.Mock() + + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop._remove_writer.assert_called_with(1) + + def test_accept_connection_multiple(self): + sock = mock.Mock() + sock.accept.return_value = (mock.Mock(), mock.Mock()) + backlog = 100 + # Mock the coroutine generation for a connection to prevent + # warnings related to un-awaited coroutines. _accept_connection2 + # is an async function that is patched with AsyncMock. create_task + # creates a task out of coroutine returned by AsyncMock, so use + # asyncio.sleep(0) to ensure created tasks are complete to avoid + # task pending warnings. + mock_obj = mock.patch.object + with mock_obj(self.loop, '_accept_connection2') as accept2_mock: + self.loop._accept_connection( + mock.Mock(), sock, backlog=backlog) + self.loop.run_until_complete(asyncio.sleep(0)) + self.assertEqual(sock.accept.call_count, backlog) + + +class SelectorTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + + def create_transport(self): + transport = _SelectorTransport(self.loop, self.sock, self.protocol, + None) + self.addCleanup(close_transport, transport) + return transport + + def test_ctor(self): + tr = self.create_transport() + self.assertIs(tr._loop, self.loop) + self.assertIs(tr._sock, self.sock) + self.assertIs(tr._sock_fd, 7) + + def test_abort(self): + tr = self.create_transport() + tr._force_close = mock.Mock() + + tr.abort() + 
tr._force_close.assert_called_with(None) + + def test_close(self): + tr = self.create_transport() + tr.close() + + self.assertTrue(tr.is_closing()) + self.assertEqual(1, self.loop.remove_reader_count[7]) + self.protocol.connection_lost(None) + self.assertEqual(tr._conn_lost, 1) + + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + def test_close_write_buffer(self): + tr = self.create_transport() + tr._buffer.extend(b'data') + tr.close() + + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_force_close(self): + tr = self.create_transport() + tr._buffer.extend(b'1') + self.loop._add_reader(7, mock.sentinel) + self.loop._add_writer(7, mock.sentinel) + tr._force_close(None) + + self.assertTrue(tr.is_closing()) + self.assertEqual(tr._buffer, list_to_buffer()) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + + # second close should not remove reader + tr._force_close(None) + self.assertFalse(self.loop.readers) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + @mock.patch('asyncio.log.logger.error') + def test_fatal_error(self, m_exc): + exc = OSError() + tr = self.create_transport() + tr._force_close = mock.Mock() + tr._fatal_error(exc) + + m_exc.assert_not_called() + + tr._force_close.assert_called_with(exc) + + @mock.patch('asyncio.log.logger.error') + def test_fatal_error_custom_exception(self, m_exc): + class MyError(Exception): + pass + exc = MyError() + tr = self.create_transport() + tr._force_close = mock.Mock() + tr._fatal_error(exc) + + m_exc.assert_called_with( + test_utils.MockPattern( + 'Fatal error on transport\nprotocol:.*\ntransport:.*'), + exc_info=(MyError, MOCK_ANY, MOCK_ANY)) + + tr._force_close.assert_called_with(exc) + + def test_connection_lost(self): + exc = OSError() + tr = self.create_transport() + self.assertIsNotNone(tr._protocol) + 
self.assertIsNotNone(tr._loop) + tr._call_connection_lost(exc) + + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() + self.assertIsNone(tr._sock) + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test__add_reader(self): + tr = self.create_transport() + tr._buffer.extend(b'1') + tr._add_reader(7, mock.sentinel) + self.assertTrue(self.loop.readers) + + tr._force_close(None) + + self.assertTrue(tr.is_closing()) + self.assertFalse(self.loop.readers) + + # can not add readers after closing + tr._add_reader(7, mock.sentinel) + self.assertFalse(self.loop.readers) + + +class SelectorSocketTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + self.sock_fd = self.sock.fileno.return_value = 7 + + def socket_transport(self, waiter=None, sendmsg=False): + transport = _SelectorSocketTransport(self.loop, self.sock, + self.protocol, waiter=waiter) + if sendmsg: + transport._write_ready = transport._write_sendmsg + else: + transport._write_ready = transport._write_send + self.addCleanup(close_transport, transport) + return transport + + def test_ctor(self): + waiter = self.loop.create_future() + tr = self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + + self.loop.assert_reader(7, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + waiter = self.loop.create_future() + self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + + self.assertIsNone(waiter.result()) + + def test_pause_resume_reading(self): + tr = self.socket_transport() + test_utils.run_briefly(self.loop) + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + self.loop.assert_reader(7, tr._read_ready) + + tr.pause_reading() + 
tr.pause_reading() + self.assertTrue(tr._paused) + self.assertFalse(tr.is_reading()) + self.loop.assert_no_reader(7) + + tr.resume_reading() + tr.resume_reading() + self.assertFalse(tr._paused) + self.assertTrue(tr.is_reading()) + self.loop.assert_reader(7, tr._read_ready) + + tr.close() + self.assertFalse(tr.is_reading()) + self.loop.assert_no_reader(7) + + def test_pause_reading_connection_made(self): + tr = self.socket_transport() + self.protocol.connection_made.side_effect = lambda _: tr.pause_reading() + test_utils.run_briefly(self.loop) + self.assertFalse(tr.is_reading()) + self.loop.assert_no_reader(7) + + tr.resume_reading() + self.assertTrue(tr.is_reading()) + self.loop.assert_reader(7, tr._read_ready) + + tr.close() + self.assertFalse(tr.is_reading()) + self.loop.assert_no_reader(7) + + + def test_read_eof_received_error(self): + transport = self.socket_transport() + transport.close = mock.Mock() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + + self.protocol.eof_received.side_effect = LookupError() + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertTrue(transport._fatal_error.called) + + def test_data_received_error(self): + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + self.protocol.data_received.side_effect = LookupError() + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertTrue(self.protocol.data_received.called) + + def test_read_ready(self): + transport = self.socket_transport() + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.protocol.data_received.assert_called_with(b'data') + + def test_read_ready_eof(self): + transport = self.socket_transport() + transport.close = mock.Mock() + + self.sock.recv.return_value = b'' + transport._read_ready() 
+ + self.protocol.eof_received.assert_called_with() + transport.close.assert_called_with() + + def test_read_ready_eof_keep_open(self): + transport = self.socket_transport() + transport.close = mock.Mock() + + self.sock.recv.return_value = b'' + self.protocol.eof_received.return_value = True + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertFalse(transport.close.called) + + @mock.patch('logging.exception') + def test_read_ready_tryagain(self, m_exc): + self.sock.recv.side_effect = BlockingIOError + + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @mock.patch('logging.exception') + def test_read_ready_tryagain_interrupted(self, m_exc): + self.sock.recv.side_effect = InterruptedError + + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @mock.patch('logging.exception') + def test_read_ready_conn_reset(self, m_exc): + err = self.sock.recv.side_effect = ConnectionResetError() + + transport = self.socket_transport() + transport._force_close = mock.Mock() + with test_utils.disable_logger(): + transport._read_ready() + transport._force_close.assert_called_with(err) + + @mock.patch('logging.exception') + def test_read_ready_err(self, m_exc): + err = self.sock.recv.side_effect = OSError() + + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on socket transport') + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = self.socket_transport() + transport.write(data) + self.sock.send.assert_called_with(data) + + def test_write_bytearray(self): + data = bytearray(b'data') + self.sock.send.return_value = len(data) + + transport = self.socket_transport() + 
transport.write(data) + self.sock.send.assert_called_with(data) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. + + def test_write_memoryview(self): + data = memoryview(b'data') + self.sock.send.return_value = len(data) + + transport = self.socket_transport() + transport.write(data) + self.sock.send.assert_called_with(data) + + def test_write_no_data(self): + transport = self.socket_transport() + transport._buffer.append(memoryview(b'data')) + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_buffer(self): + transport = self.socket_transport() + transport._buffer.append(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual(list_to_buffer([b'data1', b'data2']), + transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = self.socket_transport() + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + + def test_write_partial_bytearray(self): + data = bytearray(b'data') + self.sock.send.return_value = 2 + + transport = self.socket_transport() + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. 
+ + def test_write_partial_memoryview(self): + data = memoryview(b'data') + self.sock.send.return_value = 2 + + transport = self.socket_transport() + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + self.sock.fileno.return_value = 7 + + transport = self.socket_transport() + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + data = b'data' + transport = self.socket_transport() + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_sendmsg_no_data(self): + self.sock.sendmsg = mock.Mock() + self.sock.sendmsg.return_value = 0 + transport = self.socket_transport(sendmsg=True) + transport._buffer.append(memoryview(b'data')) + transport.write(b'') + self.assertFalse(self.sock.sendmsg.called) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg') + def test_writelines_sendmsg_full(self): + data = memoryview(b'data') + self.sock.sendmsg = mock.Mock() + self.sock.sendmsg.return_value = len(data) + + transport = self.socket_transport(sendmsg=True) + transport.writelines([data]) + self.assertTrue(self.sock.sendmsg.called) + self.assertFalse(self.loop.writers) + + @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg') + def test_writelines_sendmsg_partial(self): + data = memoryview(b'data') + self.sock.sendmsg = mock.Mock() + self.sock.sendmsg.return_value = 2 + + transport = self.socket_transport(sendmsg=True) + transport.writelines([data]) + self.assertTrue(self.sock.sendmsg.called) + self.assertTrue(self.loop.writers) + + def 
test_writelines_send_full(self): + data = memoryview(b'data') + self.sock.send.return_value = len(data) + self.sock.send.fileno.return_value = 7 + + transport = self.socket_transport() + transport.writelines([data]) + self.assertTrue(self.sock.send.called) + self.assertFalse(self.loop.writers) + + def test_writelines_send_partial(self): + data = memoryview(b'data') + self.sock.send.return_value = 2 + self.sock.send.fileno.return_value = 7 + + transport = self.socket_transport() + transport.writelines([data]) + self.assertTrue(self.sock.send.called) + self.assertTrue(self.loop.writers) + + # TODO: RUSTPYTHON + # AssertionError: False is not true + @unittest.expectedFailure + def test_writelines_pauses_protocol(self): + data = memoryview(b'data') + self.sock.send.return_value = 2 + self.sock.send.fileno.return_value = 7 + + transport = self.socket_transport() + transport._high_water = 1 + transport.writelines([data]) + self.assertTrue(self.protocol.pause_writing.called) + self.assertTrue(self.sock.send.called) + self.assertTrue(self.loop.writers) + + @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg') + def test_write_sendmsg_full(self): + data = memoryview(b'data') + self.sock.sendmsg = mock.Mock() + self.sock.sendmsg.return_value = len(data) + + transport = self.socket_transport(sendmsg=True) + transport._buffer.append(data) + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.sendmsg.called) + self.assertFalse(self.loop.writers) + + @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg') + def test_write_sendmsg_partial(self): + + data = memoryview(b'data') + self.sock.sendmsg = mock.Mock() + # Sent partial data + self.sock.sendmsg.return_value = 2 + + transport = self.socket_transport(sendmsg=True) + transport._buffer.append(data) + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.sendmsg.called) + self.assertTrue(self.loop.writers) + 
self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + + @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg') + def test_write_sendmsg_half_buffer(self): + data = [memoryview(b'data1'), memoryview(b'data2')] + self.sock.sendmsg = mock.Mock() + # Sent partial data + self.sock.sendmsg.return_value = 2 + + transport = self.socket_transport(sendmsg=True) + transport._buffer.extend(data) + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.sendmsg.called) + self.assertTrue(self.loop.writers) + self.assertEqual(list_to_buffer([b'ta1', b'data2']), transport._buffer) + + @unittest.skipUnless(selector_events._HAS_SENDMSG, 'no sendmsg') + def test_write_sendmsg_OSError(self): + data = memoryview(b'data') + self.sock.sendmsg = mock.Mock() + err = self.sock.sendmsg.side_effect = OSError() + + transport = self.socket_transport(sendmsg=True) + transport._fatal_error = mock.Mock() + transport._buffer.extend(data) + # Calls _fatal_error and clears the buffer + transport._write_ready() + self.assertTrue(self.sock.sendmsg.called) + self.assertFalse(self.loop.writers) + self.assertEqual(list_to_buffer([]), transport._buffer) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on socket transport') + + @mock.patch('asyncio.selector_events.logger') + def test_write_exception(self, m_log): + err = self.sock.send.side_effect = OSError() + + data = b'data' + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport.write(data) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on socket transport') + transport._conn_lost = 1 + + self.sock.reset_mock() + transport.write(data) + self.assertFalse(self.sock.send.called) + self.assertEqual(transport._conn_lost, 2) + transport.write(data) + transport.write(data) + transport.write(data) + transport.write(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def 
test_write_str(self): + transport = self.socket_transport() + self.assertRaises(TypeError, transport.write, 'str') + + def test_write_closing(self): + transport = self.socket_transport() + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = self.socket_transport() + transport._buffer.append(data) + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertFalse(self.loop.writers) + + def test_write_ready_closing(self): + data = memoryview(b'data') + self.sock.send.return_value = len(data) + + transport = self.socket_transport() + transport._closing = True + transport._buffer.append(data) + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.skipIf(sys.flags.optimize, "Assertions are disabled in optimized mode") + def test_write_ready_no_data(self): + transport = self.socket_transport() + # This is an internal error. 
+ self.assertRaises(AssertionError, transport._write_ready) + + def test_write_ready_partial(self): + data = memoryview(b'data') + self.sock.send.return_value = 2 + + transport = self.socket_transport() + transport._buffer.append(data) + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = self.socket_transport() + transport._buffer.append(data) + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_ready_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + transport = self.socket_transport() + buffer = list_to_buffer([b'data1', b'data2']) + transport._buffer = buffer + self.loop._add_writer(7, transport._write_ready) + transport._write_ready() + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(buffer, transport._buffer) + + def test_write_ready_exception(self): + err = self.sock.send.side_effect = OSError() + + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport._buffer.extend(b'data') + transport._write_ready() + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on socket transport') + + def test_write_eof(self): + tr = self.socket_transport() + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = self.socket_transport() + self.sock.send.side_effect = BlockingIOError + tr.write(b'data') + tr.write_eof() + self.assertEqual(tr._buffer, list_to_buffer([b'data'])) + self.assertTrue(tr._eof) + 
self.assertFalse(self.sock.shutdown.called) + self.sock.send.side_effect = lambda _: 4 + tr._write_ready() + self.assertTrue(self.sock.send.called) + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_after_close(self): + tr = self.socket_transport() + tr.close() + self.loop.run_until_complete(asyncio.sleep(0)) + tr.write_eof() + + @mock.patch('asyncio.base_events.logger') + def test_transport_close_remove_writer(self, m_log): + remove_writer = self.loop._remove_writer = mock.Mock() + + transport = self.socket_transport() + transport.close() + remove_writer.assert_called_with(self.sock_fd) + + # TODO: RUSTPYTHON + # AssertionError: 2 != 0 + @unittest.expectedFailure + def test_write_buffer_after_close(self): + # gh-115514: If the transport is closed while: + # * Transport write buffer is not empty + # * Transport is paused + # * Protocol has data in its buffer, like SSLProtocol in self._outgoing + # The data is still written out. + + # Also tested with real SSL transport in + # test.test_asyncio.test_ssl.TestSSL.test_remote_shutdown_receives_trailing_data + + data = memoryview(b'data') + self.sock.send.return_value = 2 + self.sock.send.fileno.return_value = 7 + + def _resume_writing(): + transport.write(b"data") + self.protocol.resume_writing.side_effect = None + + self.protocol.resume_writing.side_effect = _resume_writing + + transport = self.socket_transport() + transport._high_water = 1 + + transport.write(data) + + self.assertTrue(transport._protocol_paused) + self.assertTrue(self.sock.send.called) + self.loop.assert_writer(7, transport._write_ready) + + transport.close() + + # not called, we still have data in write buffer + self.assertFalse(self.protocol.connection_lost.called) + + self.loop.writers[7]._run() + # during this ^ run, the _resume_writing mock above was called and added more data + + self.assertEqual(transport.get_write_buffer_size(), 2) + self.loop.writers[7]._run() + + 
self.assertEqual(transport.get_write_buffer_size(), 0) + self.assertTrue(self.protocol.connection_lost.called) + +class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + + self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol) + self.buf = bytearray(1) + self.protocol.get_buffer.side_effect = lambda hint: self.buf + + self.sock = mock.Mock(socket.socket) + self.sock_fd = self.sock.fileno.return_value = 7 + + def socket_transport(self, waiter=None): + transport = _SelectorSocketTransport(self.loop, self.sock, + self.protocol, waiter=waiter) + self.addCleanup(close_transport, transport) + return transport + + def test_ctor(self): + waiter = self.loop.create_future() + tr = self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + + self.loop.assert_reader(7, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_get_buffer_error(self): + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + self.protocol.get_buffer.side_effect = LookupError() + + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertTrue(self.protocol.get_buffer.called) + self.assertFalse(self.protocol.buffer_updated.called) + + def test_get_buffer_zerosized(self): + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + self.protocol.get_buffer.side_effect = lambda hint: bytearray(0) + + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertTrue(self.protocol.get_buffer.called) + self.assertFalse(self.protocol.buffer_updated.called) + + def test_proto_type_switch(self): + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + transport = self.socket_transport() + + 
self.sock.recv.return_value = b'data' + transport._read_ready() + + self.protocol.data_received.assert_called_with(b'data') + + # switch protocol to a BufferedProtocol + + buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol) + buf = bytearray(4) + buf_proto.get_buffer.side_effect = lambda hint: buf + + transport.set_protocol(buf_proto) + + self.sock.recv_into.return_value = 10 + transport._read_ready() + + buf_proto.get_buffer.assert_called_with(-1) + buf_proto.buffer_updated.assert_called_with(10) + + def test_buffer_updated_error(self): + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + self.protocol.buffer_updated.side_effect = LookupError() + + self.sock.recv_into.return_value = 10 + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertTrue(self.protocol.get_buffer.called) + self.assertTrue(self.protocol.buffer_updated.called) + + def test_read_eof_received_error(self): + transport = self.socket_transport() + transport.close = mock.Mock() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + + self.protocol.eof_received.side_effect = LookupError() + + self.sock.recv_into.return_value = 0 + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertTrue(transport._fatal_error.called) + + def test_read_ready(self): + transport = self.socket_transport() + + self.sock.recv_into.return_value = 10 + transport._read_ready() + + self.protocol.get_buffer.assert_called_with(-1) + self.protocol.buffer_updated.assert_called_with(10) + + def test_read_ready_eof(self): + transport = self.socket_transport() + transport.close = mock.Mock() + + self.sock.recv_into.return_value = 0 + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + transport.close.assert_called_with() + + def test_read_ready_eof_keep_open(self): + transport = self.socket_transport() + 
transport.close = mock.Mock() + + self.sock.recv_into.return_value = 0 + self.protocol.eof_received.return_value = True + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertFalse(transport.close.called) + + @mock.patch('logging.exception') + def test_read_ready_tryagain(self, m_exc): + self.sock.recv_into.side_effect = BlockingIOError + + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @mock.patch('logging.exception') + def test_read_ready_tryagain_interrupted(self, m_exc): + self.sock.recv_into.side_effect = InterruptedError + + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @mock.patch('logging.exception') + def test_read_ready_conn_reset(self, m_exc): + err = self.sock.recv_into.side_effect = ConnectionResetError() + + transport = self.socket_transport() + transport._force_close = mock.Mock() + with test_utils.disable_logger(): + transport._read_ready() + transport._force_close.assert_called_with(err) + + @mock.patch('logging.exception') + def test_read_ready_err(self, m_exc): + err = self.sock.recv_into.side_effect = OSError() + + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on socket transport') + + +class SelectorDatagramTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) + self.sock = mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + + def datagram_transport(self, address=None): + self.sock.getpeername.side_effect = None if address else OSError + transport = _SelectorDatagramTransport(self.loop, self.sock, + self.protocol, + 
address=address) + self.addCleanup(close_transport, transport) + return transport + + def test_read_ready(self): + transport = self.datagram_transport() + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.protocol.datagram_received.assert_called_with( + b'data', ('0.0.0.0', 1234)) + + def test_transport_inheritance(self): + transport = self.datagram_transport() + self.assertIsInstance(transport, asyncio.DatagramTransport) + + def test_read_ready_tryagain(self): + transport = self.datagram_transport() + + self.sock.recvfrom.side_effect = BlockingIOError + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = self.datagram_transport() + + err = self.sock.recvfrom.side_effect = RuntimeError() + transport._fatal_error = mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on datagram transport') + + def test_read_ready_oserr(self): + transport = self.datagram_transport() + + err = self.sock.recvfrom.side_effect = OSError() + transport._fatal_error = mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + self.protocol.error_received.assert_called_with(err) + + def test_sendto(self): + data = b'data' + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_bytearray(self): + data = bytearray(b'data') + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_memoryview(self): + data = memoryview(b'data') + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 1234)) + 
self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + # TODO: RUSTPYTHON + # AssertionError: False is not true + @unittest.expectedFailure + def test_sendto_no_data(self): + transport = self.datagram_transport() + transport.sendto(b'', ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (b'', ('0.0.0.0', 1234))) + + def test_sendto_buffer(self): + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_buffer_bytearray(self): + data2 = bytearray(b'data2') + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + def test_sendto_buffer_memoryview(self): + data2 = memoryview(b'data2') + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + # TODO: RUSTPYTHON + # AssertionError: Lists differ: [(b'data1', ('0.0.0.0', 12345)), (b'', ('0.0.0.0', 12345))] != [(b'data1', ('0.0.0.0', 12345))] + @unittest.expectedFailure + def test_sendto_buffer_nodata(self): + data2 = b'' + transport = self.datagram_transport() + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + 
transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + def test_sendto_tryagain(self): + data = b'data' + + self.sock.sendto.side_effect = BlockingIOError + + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 12345)) + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + @mock.patch('asyncio.selector_events.logger') + def test_sendto_exception(self, m_log): + data = b'data' + err = self.sock.sendto.side_effect = RuntimeError() + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + transport._conn_lost = 1 + + transport._address = ('123',) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_sendto_error_received(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport.sendto(data, ()) + + self.assertEqual(transport._conn_lost, 0) + self.assertFalse(transport._fatal_error.called) + + def test_sendto_error_received_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = self.datagram_transport(address=('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport.sendto(data) + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + def test_sendto_str(self): + transport = 
self.datagram_transport() + self.assertRaises(TypeError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = self.datagram_transport(address=('0.0.0.0', 1)) + self.assertRaises( + ValueError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = self.datagram_transport(address=(1,)) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.sendto(b'data', (1,)) + self.assertEqual(transport._conn_lost, 2) + + @unittest.skip('TODO: RUSTPYTHON') + # '_SelectorDatagramTransport' object has no attribute '_header_size' + def test_sendto_sendto_ready(self): + data = b'data' + + # First queue up the buffer by having the socket blocked + self.sock.sendto.side_effect = BlockingIOError + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 12345)) + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual(1, len(transport._buffer)) + self.assertEqual(transport._buffer_size, len(data) + transport._header_size) + + # Now let the socket send the buffer + self.sock.sendto.side_effect = None + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertFalse(self.loop.writers) + self.assertFalse(transport._buffer) + self.assertEqual(transport._buffer_size, 0) + + @unittest.skip('TODO: RUSTPYTHON') + # '_SelectorDatagramTransport' object has no attribute '_header_size' + def test_sendto_sendto_ready_blocked(self): + data = b'data' + + # First queue up the buffer by having the socket blocked + self.sock.sendto.side_effect = BlockingIOError + transport = self.datagram_transport() + transport.sendto(data, ('0.0.0.0', 12345)) + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual(1, len(transport._buffer)) + self.assertEqual(transport._buffer_size, len(data) + transport._header_size) + + # Now try to send the buffer, it will be added to buffer 
again if it fails + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertTrue(self.loop.writers) + self.assertEqual(1, len(transport._buffer)) + self.assertEqual(transport._buffer_size, len(data) + transport._header_size) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = self.datagram_transport() + transport._buffer.append((data, ('0.0.0.0', 12345))) + self.loop._add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = self.datagram_transport() + transport._closing = True + transport._buffer.append((data, ())) + self.loop._add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.sock.sendto.assert_called_with(data, ()) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test_sendto_ready_no_data(self): + transport = self.datagram_transport() + self.loop._add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_tryagain(self): + self.sock.sendto.side_effect = BlockingIOError + + transport = self.datagram_transport() + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + self.loop._add_writer(7, transport._sendto_ready) + transport._sendto_ready() + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + err = self.sock.sendto.side_effect = RuntimeError() + + 
transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') + + def test_sendto_ready_error_received(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = self.datagram_transport() + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_error_received_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = self.datagram_transport(address=('0.0.0.0', 1)) + transport._fatal_error = mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) + + @mock.patch('asyncio.base_events.logger.error') + def test_fatal_error_connected(self, m_exc): + transport = self.datagram_transport(address=('0.0.0.0', 1)) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.assertFalse(self.protocol.error_received.called) + m_exc.assert_not_called() + + @mock.patch('asyncio.base_events.logger.error') + def test_fatal_error_connected_custom_error(self, m_exc): + class MyException(Exception): + pass + transport = self.datagram_transport(address=('0.0.0.0', 1)) + err = MyException() + transport._fatal_error(err) + self.assertFalse(self.protocol.error_received.called) + m_exc.assert_called_with( + test_utils.MockPattern( + 'Fatal error on transport\nprotocol:.*\ntransport:.*'), + exc_info=(MyException, MOCK_ANY, MOCK_ANY)) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_sendfile.py b/Lib/test/test_asyncio/test_sendfile.py new file mode 100644 index 00000000000..2509d4382cd --- /dev/null +++ b/Lib/test/test_asyncio/test_sendfile.py @@ -0,0 
+1,585 @@ +"""Tests for sendfile functionality.""" + +import asyncio +import errno +import os +import socket +import sys +import tempfile +import unittest +from asyncio import base_events +from asyncio import constants +from unittest import mock +from test import support +from test.support import os_helper +from test.support import socket_helper +from test.test_asyncio import utils as test_utils + +try: + import ssl +except ImportError: + ssl = None + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class MySendfileProto(asyncio.Protocol): + + def __init__(self, loop=None, close_after=0): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = loop.create_future() + self.done = loop.create_future() + self.data = bytearray() + self.close_after = close_after + + def _assert_state(self, *expected): + if self.state not in expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + + def eof_received(self): + self._assert_state('CONNECTED') + self.state = 'EOF' + + def connection_lost(self, exc): + self._assert_state('CONNECTED', 'EOF') + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + def data_received(self, data): + self._assert_state('CONNECTED') + self.nbytes += len(data) + self.data.extend(data) + super().data_received(data) + if self.close_after and self.nbytes >= self.close_after: + self.transport.close() + + +class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + + def connection_made(self, transport): + self.started = True + self.transport = transport + + def data_received(self, data): + self.data.extend(data) + + def 
connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + + async def wait_closed(self): + await self.fut + + +class SendfileBase: + + # Linux >= 6.10 seems buffering up to 17 pages of data. + # So DATA should be large enough to make this test reliable even with a + # 64 KiB page configuration. + DATA = b"x" * (1024 * 17 * 64 + 1) + # Reduce socket buffer size to test on relative small data sets. + BUF_SIZE = 4 * 1024 # 4 KiB + + def create_event_loop(self): + raise NotImplementedError + + @classmethod + def setUpClass(cls): + with open(os_helper.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + @classmethod + def tearDownClass(cls): + os_helper.unlink(os_helper.TESTFN) + super().tearDownClass() + + def setUp(self): + self.file = open(os_helper.TESTFN, 'rb') + self.addCleanup(self.file.close) + self.loop = self.create_event_loop() + self.set_event_loop(self.loop) + super().setUp() + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.doCleanups() + support.gc_collect() + super().tearDown() + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + +class SockSendfileMixin(SendfileBase): + + @classmethod + def setUpClass(cls): + cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE + constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 + super().setUpClass() + + @classmethod + def tearDownClass(cls): + constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize + super().tearDownClass() + + def make_socket(self, cleanup=True): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + if cleanup: + self.addCleanup(sock.close) + return sock + + def reduce_receive_buffer_size(self, sock): + # Reduce receive socket buffer size to test on relative + # small data sets. 
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE) + + def reduce_send_buffer_size(self, sock, transport=None): + # Reduce send socket buffer size to test on relative small data sets. + + # On macOS, SO_SNDBUF is reset by connect(). So this method + # should be called after the socket is connected. + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE) + + if transport is not None: + transport.set_write_buffer_limits(high=self.BUF_SIZE) + + def prepare_socksendfile(self): + proto = MyProto(self.loop) + port = socket_helper.find_unused_port() + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind((socket_helper.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: proto, sock=srv_sock)) + self.reduce_receive_buffer_size(srv_sock) + + sock = self.make_socket() + self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) + self.reduce_send_buffer_size(sock) + + def cleanup(): + if proto.transport is not None: + # can be None if the task was cancelled before + # connection_made callback + proto.transport.close() + self.run_loop(proto.wait_closed()) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return sock, proto + + def test_sock_sendfile_success(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sock_sendfile_with_offset_and_count(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, + 1000, 2000)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(proto.data, self.DATA[1000:3000]) + self.assertEqual(self.file.tell(), 3000) + self.assertEqual(ret, 2000) + + def test_sock_sendfile_zero_size(self): + sock, proto = 
self.prepare_socksendfile() + with tempfile.TemporaryFile() as f: + ret = self.run_loop(self.loop.sock_sendfile(sock, f, + 0, None)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_mix_with_regular_send(self): + buf = b"mix_regular_send" * (4 * 1024) # 64 KiB + sock, proto = self.prepare_socksendfile() + self.run_loop(self.loop.sock_sendall(sock, buf)) + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + self.run_loop(self.loop.sock_sendall(sock, buf)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + expected = buf + self.DATA + buf + self.assertEqual(proto.data, expected) + self.assertEqual(self.file.tell(), len(self.DATA)) + + +class SendfileMixin(SendfileBase): + + # Note: sendfile via SSL transport is equal to sendfile fallback + + def prepare_sendfile(self, *, is_ssl=False, close_after=0): + port = socket_helper.find_unused_port() + srv_proto = MySendfileProto(loop=self.loop, + close_after=close_after) + if is_ssl: + if not ssl: + self.skipTest("No ssl module") + srv_ctx = test_utils.simple_server_sslcontext() + cli_ctx = test_utils.simple_client_sslcontext() + else: + srv_ctx = None + cli_ctx = None + srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv_sock.bind((socket_helper.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) + self.reduce_receive_buffer_size(srv_sock) + + if is_ssl: + server_hostname = socket_helper.HOST + else: + server_hostname = None + cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + cli_sock.connect((socket_helper.HOST, port)) + + cli_proto = MySendfileProto(loop=self.loop) + tr, pr = self.run_loop(self.loop.create_connection( + lambda: cli_proto, sock=cli_sock, + ssl=cli_ctx, server_hostname=server_hostname)) + self.reduce_send_buffer_size(cli_sock, transport=tr) + + def cleanup(): + 
srv_proto.transport.close() + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.run_loop(cli_proto.done) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + return srv_proto, cli_proto + + @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") + def test_sendfile_not_supported(self): + tr, pr = self.run_loop( + self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, + family=socket.AF_INET)) + try: + with self.assertRaisesRegex(RuntimeError, "not supported"): + self.run_loop( + self.loop.sendfile(tr, self.file)) + self.assertEqual(0, self.file.tell()) + finally: + # don't use self.addCleanup because it produces resource warning + tr.close() + + def test_sendfile(self): + srv_proto, cli_proto = self.prepare_sendfile() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_fallback(self): + srv_proto, cli_proto = self.prepare_sendfile() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_unsupported_native(self): + if sys.platform == 'win32': + if isinstance(self.loop, asyncio.ProactorEventLoop): + self.skipTest("Fails on proactor event loop") + 
srv_proto, cli_proto = self.prepare_sendfile() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not supported"): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, + fallback=False)) + + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_ssl(self): + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_for_closing_transp(self): + srv_proto, cli_proto = self.prepare_sendfile() + cli_proto.transport.close() + with self.assertRaisesRegex(RuntimeError, "is closing"): + self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_pre_and_post_data(self): + srv_proto, cli_proto = self.prepare_sendfile() + PREFIX = b'PREFIX__' * 1024 # 8 KiB + SUFFIX = b'--SUFFIX' * 1024 # 8 KiB + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_pre_and_post_data(self): + srv_proto, 
cli_proto = self.prepare_sendfile(is_ssl=True) + PREFIX = b'zxcvbnm' * 1024 + SUFFIX = b'0987654321' * 1024 + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_partial(self): + srv_proto, cli_proto = self.prepare_sendfile() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_ssl_partial(self): + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare_sendfile( + close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare_sendfile( + is_ssl=True, close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + 
self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been + # established has no effect. Due to its age, this bug affects both Oracle + # Solaris as well as all other OpenSolaris forks (unless they fixed it + # themselves). + @unittest.skipIf(sys.platform.startswith('sunos'), + "Doesn't work on Solaris") + def test_sendfile_close_peer_in_the_middle_of_receiving(self): + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) + with self.assertRaises(ConnectionError): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + if not (sys.platform == 'win32' + and isinstance(self.loop, asyncio.ProactorEventLoop)): + # On Windows, Proactor uses transmitFile, which does not update tell() + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + self.assertTrue(cli_proto.transport.is_closing()) + + def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) + with self.assertRaises(ConnectionError): + try: + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + except OSError as e: + # macOS may raise OSError of EPROTOTYPE when writing to a + # socket that is in the process of closing down. 
+ if e.errno == errno.EPROTOTYPE and sys.platform == "darwin": + raise ConnectionError + else: + raise + + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + + @unittest.skipIf(not hasattr(os, 'sendfile'), + "Don't have native sendfile support") + def test_sendfile_prevents_bare_write(self): + srv_proto, cli_proto = self.prepare_sendfile() + fut = self.loop.create_future() + + async def coro(): + fut.set_result(None) + return await self.loop.sendfile(cli_proto.transport, self.file) + + t = self.loop.create_task(coro()) + self.run_loop(fut) + with self.assertRaisesRegex(RuntimeError, + "sendfile is in progress"): + cli_proto.transport.write(b'data') + ret = self.run_loop(t) + self.assertEqual(ret, len(self.DATA)) + + def test_sendfile_no_fallback_for_fallback_transport(self): + transport = mock.Mock() + transport.is_closing.side_effect = lambda: False + transport._sendfile_compatible = constants._SendfileMode.FALLBACK + with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'): + self.loop.run_until_complete( + self.loop.sendfile(transport, None, fallback=False)) + + +class SendfileTestsBase(SendfileMixin, SockSendfileMixin): + pass + + +if sys.platform == 'win32': + + class SelectEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop() + + class ProactorEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.ProactorEventLoop() + +else: + import selectors + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def 
create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.SelectSelector()) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_server.py b/Lib/test/test_asyncio/test_server.py new file mode 100644 index 00000000000..e4232ff0d5d --- /dev/null +++ b/Lib/test/test_asyncio/test_server.py @@ -0,0 +1,364 @@ +import asyncio +import os +import socket +import time +import threading +import unittest + +from test.support import socket_helper +from test.test_asyncio import utils as test_utils +from test.test_asyncio import functional as func_tests + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class BaseStartServer(func_tests.FunctionalTestCaseMixin): + + def new_loop(self): + raise NotImplementedError + + def test_start_server_1(self): + HELLO_MSG = b'1' * 1024 * 5 + b'\n' + + def client(sock, addr): + for i in range(10): + time.sleep(0.2) + if srv.is_serving(): + break + else: + raise RuntimeError + + sock.settimeout(2) + sock.connect(addr) + sock.send(HELLO_MSG) + sock.recv_all(1) + sock.close() + + async def serve(reader, writer): + await reader.readline() + main_task.cancel() + writer.write(b'1') + writer.close() + await writer.wait_closed() + + async def main(srv): + async with srv: + await srv.serve_forever() + + srv = self.loop.run_until_complete(asyncio.start_server( + serve, socket_helper.HOSTv4, 0, start_serving=False)) + + self.assertFalse(srv.is_serving()) + + main_task = self.loop.create_task(main(srv)) + + addr = srv.sockets[0].getsockname() + with self.assertRaises(asyncio.CancelledError): + 
with self.tcp_client(lambda sock: client(sock, addr)): + self.loop.run_until_complete(main_task) + + self.assertEqual(srv.sockets, ()) + + self.assertIsNone(srv._sockets) + self.assertIsNone(srv._waiters) + self.assertFalse(srv.is_serving()) + + with self.assertRaisesRegex(RuntimeError, r'is closed'): + self.loop.run_until_complete(srv.serve_forever()) + + +class SelectorStartServerTests(BaseStartServer, unittest.TestCase): + + def new_loop(self): + return asyncio.SelectorEventLoop() + + @socket_helper.skip_unless_bind_unix_socket + def test_start_unix_server_1(self): + HELLO_MSG = b'1' * 1024 * 5 + b'\n' + started = threading.Event() + + def client(sock, addr): + sock.settimeout(2) + started.wait(5) + sock.connect(addr) + sock.send(HELLO_MSG) + sock.recv_all(1) + sock.close() + + async def serve(reader, writer): + await reader.readline() + main_task.cancel() + writer.write(b'1') + writer.close() + await writer.wait_closed() + + async def main(srv): + async with srv: + self.assertFalse(srv.is_serving()) + await srv.start_serving() + self.assertTrue(srv.is_serving()) + started.set() + await srv.serve_forever() + + with test_utils.unix_socket_path() as addr: + srv = self.loop.run_until_complete(asyncio.start_unix_server( + serve, addr, start_serving=False)) + + main_task = self.loop.create_task(main(srv)) + + with self.assertRaises(asyncio.CancelledError): + with self.unix_client(lambda sock: client(sock, addr)): + self.loop.run_until_complete(main_task) + + self.assertEqual(srv.sockets, ()) + + self.assertIsNone(srv._sockets) + self.assertIsNone(srv._waiters) + self.assertFalse(srv.is_serving()) + + with self.assertRaisesRegex(RuntimeError, r'is closed'): + self.loop.run_until_complete(srv.serve_forever()) + + +class TestServer2(unittest.IsolatedAsyncioTestCase): + + async def test_wait_closed_basic(self): + async def serve(rd, wr): + try: + await rd.read() + finally: + wr.close() + await wr.wait_closed() + + srv = await asyncio.start_server(serve, 
socket_helper.HOSTv4, 0) + self.addCleanup(srv.close) + + # active count = 0, not closed: should block + task1 = asyncio.create_task(srv.wait_closed()) + await asyncio.sleep(0) + self.assertFalse(task1.done()) + + # active count != 0, not closed: should block + addr = srv.sockets[0].getsockname() + (rd, wr) = await asyncio.open_connection(addr[0], addr[1]) + task2 = asyncio.create_task(srv.wait_closed()) + await asyncio.sleep(0) + self.assertFalse(task1.done()) + self.assertFalse(task2.done()) + + srv.close() + await asyncio.sleep(0) + # active count != 0, closed: should block + task3 = asyncio.create_task(srv.wait_closed()) + await asyncio.sleep(0) + self.assertFalse(task1.done()) + self.assertFalse(task2.done()) + self.assertFalse(task3.done()) + + wr.close() + await wr.wait_closed() + # active count == 0, closed: should unblock + await task1 + await task2 + await task3 + await srv.wait_closed() # Return immediately + + async def test_wait_closed_race(self): + # Test a regression in 3.12.0, should be fixed in 3.12.1 + async def serve(rd, wr): + try: + await rd.read() + finally: + wr.close() + await wr.wait_closed() + + srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0) + self.addCleanup(srv.close) + + task = asyncio.create_task(srv.wait_closed()) + await asyncio.sleep(0) + self.assertFalse(task.done()) + addr = srv.sockets[0].getsockname() + (rd, wr) = await asyncio.open_connection(addr[0], addr[1]) + loop = asyncio.get_running_loop() + loop.call_soon(srv.close) + loop.call_soon(wr.close) + await srv.wait_closed() + + @unittest.skip('TODO: RUSTPYTHON') + # AttributeError: 'Server' object has no attribute 'close_clients' + async def test_close_clients(self): + async def serve(rd, wr): + try: + await rd.read() + finally: + wr.close() + await wr.wait_closed() + + srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0) + self.addCleanup(srv.close) + + addr = srv.sockets[0].getsockname() + (rd, wr) = await asyncio.open_connection(addr[0], 
addr[1]) + self.addCleanup(wr.close) + + task = asyncio.create_task(srv.wait_closed()) + await asyncio.sleep(0) + self.assertFalse(task.done()) + + srv.close() + srv.close_clients() + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertTrue(task.done()) + + @unittest.skip('TODO: RUSTPYTHON') + # AttributeError: 'Server' object has no attribute 'abort_clients' + async def test_abort_clients(self): + async def serve(rd, wr): + fut.set_result((rd, wr)) + await wr.wait_closed() + + fut = asyncio.Future() + srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0) + self.addCleanup(srv.close) + + addr = srv.sockets[0].getsockname() + (c_rd, c_wr) = await asyncio.open_connection(addr[0], addr[1], limit=4096) + self.addCleanup(c_wr.close) + + (s_rd, s_wr) = await fut + + # Limit the socket buffers so we can more reliably overfill them + s_sock = s_wr.get_extra_info('socket') + s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536) + c_sock = c_wr.get_extra_info('socket') + c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536) + + # Get the reader in to a paused state by sending more than twice + # the configured limit + s_wr.write(b'a' * 4096) + s_wr.write(b'a' * 4096) + s_wr.write(b'a' * 4096) + while c_wr.transport.is_reading(): + await asyncio.sleep(0) + + # Get the writer in a waiting state by sending data until the + # kernel stops accepting more data in the send buffer. + # gh-122136: getsockopt() does not reliably report the buffer size + # available for message content. + # We loop until we start filling up the asyncio buffer. 
+ # To avoid an infinite loop we cap at 10 times the expected value + c_bufsize = c_sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF) + s_bufsize = s_sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) + for i in range(10): + s_wr.write(b'a' * c_bufsize) + s_wr.write(b'a' * s_bufsize) + if s_wr.transport.get_write_buffer_size() > 0: + break + self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0) + + task = asyncio.create_task(srv.wait_closed()) + await asyncio.sleep(0) + self.assertFalse(task.done()) + + srv.close() + srv.abort_clients() + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertTrue(task.done()) + + +# Test the various corner cases of Unix server socket removal +class UnixServerCleanupTests(unittest.IsolatedAsyncioTestCase): + @socket_helper.skip_unless_bind_unix_socket + # TODO: RUSTPYTHON + # AssertionError: True is not false + @unittest.expectedFailure + async def test_unix_server_addr_cleanup(self): + # Default scenario + with test_utils.unix_socket_path() as addr: + async def serve(*args): + pass + + srv = await asyncio.start_unix_server(serve, addr) + + srv.close() + self.assertFalse(os.path.exists(addr)) + + @socket_helper.skip_unless_bind_unix_socket + # TODO: RUSTPYTHON + # AssertionError: True is not false + @unittest.expectedFailure + async def test_unix_server_sock_cleanup(self): + # Using already bound socket + with test_utils.unix_socket_path() as addr: + async def serve(*args): + pass + + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.bind(addr) + + srv = await asyncio.start_unix_server(serve, sock=sock) + + srv.close() + self.assertFalse(os.path.exists(addr)) + + @socket_helper.skip_unless_bind_unix_socket + async def test_unix_server_cleanup_gone(self): + # Someone else has already cleaned up the socket + with test_utils.unix_socket_path() as addr: + async def serve(*args): + pass + + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.bind(addr) + + srv = await 
asyncio.start_unix_server(serve, sock=sock) + + os.unlink(addr) + + srv.close() + + @socket_helper.skip_unless_bind_unix_socket + async def test_unix_server_cleanup_replaced(self): + # Someone else has replaced the socket with their own + with test_utils.unix_socket_path() as addr: + async def serve(*args): + pass + + srv = await asyncio.start_unix_server(serve, addr) + + os.unlink(addr) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.bind(addr) + + srv.close() + self.assertTrue(os.path.exists(addr)) + + @socket_helper.skip_unless_bind_unix_socket + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: _UnixSelectorEventLoop.create_unix_server() got an unexpected keyword argument 'cleanup_socket' + async def test_unix_server_cleanup_prevented(self): + # Automatic cleanup explicitly disabled + with test_utils.unix_socket_path() as addr: + async def serve(*args): + pass + + srv = await asyncio.start_unix_server(serve, addr, cleanup_socket=False) + + srv.close() + self.assertTrue(os.path.exists(addr)) + + +@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') +class ProactorStartServerTests(BaseStartServer, unittest.TestCase): + + def new_loop(self): + return asyncio.ProactorEventLoop() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_sock_lowlevel.py b/Lib/test/test_asyncio/test_sock_lowlevel.py new file mode 100644 index 00000000000..acef24a703b --- /dev/null +++ b/Lib/test/test_asyncio/test_sock_lowlevel.py @@ -0,0 +1,679 @@ +import socket +import asyncio +import sys +import unittest + +from asyncio import proactor_events +from itertools import cycle, islice +from unittest.mock import Mock +from test.test_asyncio import utils as test_utils +from test import support +from test.support import socket_helper + +if socket_helper.tcp_blackhole(): + raise unittest.SkipTest('Not relevant to ProactorEventLoop') + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class 
MyProto(asyncio.Protocol): + connected = None + done = None + + def __init__(self, loop=None): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = loop.create_future() + self.done = loop.create_future() + + def _assert_state(self, *expected): + if self.state not in expected: + raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') + + def connection_made(self, transport): + self.transport = transport + self._assert_state('INITIAL') + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + self._assert_state('CONNECTED') + self.nbytes += len(data) + + def eof_received(self): + self._assert_state('CONNECTED') + self.state = 'EOF' + + def connection_lost(self, exc): + self._assert_state('CONNECTED', 'EOF') + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseSockTestsMixin: + + def create_event_loop(self): + raise NotImplementedError + + def setUp(self): + self.loop = self.create_event_loop() + self.set_event_loop(self.loop) + super().setUp() + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.doCleanups() + support.gc_collect() + super().tearDown() + + def _basetest_sock_client_ops(self, httpd, sock): + if not isinstance(self.loop, proactor_events.BaseProactorEventLoop): + # in debug mode, socket operations must fail + # if the socket is not in blocking mode + self.loop.set_debug(True) + sock.setblocking(True) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + 
self.loop.sock_recv(sock, 1024)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_recv_into(sock, bytearray())) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_accept(sock)) + + # test in non-blocking mode + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def _basetest_sock_recv_into(self, httpd, sock): + # same as _basetest_sock_client_ops, but using sock_recv_into + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = bytearray(1024) + with memoryview(data) as buf: + nbytes = self.loop.run_until_complete( + self.loop.sock_recv_into(sock, buf[:1024])) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv_into(sock, buf[nbytes:])) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + self._basetest_sock_client_ops(httpd, sock) + sock = socket.socket() + self._basetest_sock_recv_into(httpd, sock) + + async def _basetest_sock_recv_racing(self, httpd, sock): + sock.setblocking(False) + await self.loop.sock_connect(sock, httpd.address) + + task = asyncio.create_task(self.loop.sock_recv(sock, 1024)) + await asyncio.sleep(0) + task.cancel() + + asyncio.create_task( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = await self.loop.sock_recv(sock, 1024) + # consume data + await self.loop.sock_recv(sock, 1024) + + 
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + async def _basetest_sock_recv_into_racing(self, httpd, sock): + sock.setblocking(False) + await self.loop.sock_connect(sock, httpd.address) + + data = bytearray(1024) + with memoryview(data) as buf: + task = asyncio.create_task( + self.loop.sock_recv_into(sock, buf[:1024])) + await asyncio.sleep(0) + task.cancel() + + task = asyncio.create_task( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + nbytes = await self.loop.sock_recv_into(sock, buf[:1024]) + # consume data + await self.loop.sock_recv_into(sock, buf[nbytes:]) + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + await task + + async def _basetest_sock_send_racing(self, listener, sock): + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + + # make connection + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setblocking(False) + task = asyncio.create_task( + self.loop.sock_connect(sock, listener.getsockname())) + await asyncio.sleep(0) + server = listener.accept()[0] + server.setblocking(False) + + with server: + await task + + # fill the buffer until sending 5 chars would block + size = 8192 + while size >= 4: + with self.assertRaises(BlockingIOError): + while True: + sock.send(b' ' * size) + size = int(size / 2) + + # cancel a blocked sock_sendall + task = asyncio.create_task( + self.loop.sock_sendall(sock, b'hello')) + await asyncio.sleep(0) + task.cancel() + + # receive everything that is not a space + async def recv_all(): + rv = b'' + while True: + buf = await self.loop.sock_recv(server, 8192) + if not buf: + return rv + rv += buf.strip() + task = asyncio.create_task(recv_all()) + + # immediately make another sock_sendall call + await self.loop.sock_sendall(sock, b'world') + sock.shutdown(socket.SHUT_WR) + data = await task + # ProactorEventLoop could deliver hello, so endswith is necessary + self.assertTrue(data.endswith(b'world')) + + # After the first connect attempt before the listener is ready, + # the 
socket needs time to "recover" to make the next connect call. + # On Linux, a second retry will do. On Windows, the waiting time is + # unpredictable; and on FreeBSD the socket may never come back + # because it's a loopback address. Here we'll just retry for a few + # times, and have to skip the test if it's not working. See also: + # https://stackoverflow.com/a/54437602/3316267 + # https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html + async def _basetest_sock_connect_racing(self, listener, sock): + listener.bind(('127.0.0.1', 0)) + addr = listener.getsockname() + sock.setblocking(False) + + task = asyncio.create_task(self.loop.sock_connect(sock, addr)) + await asyncio.sleep(0) + task.cancel() + + listener.listen(1) + + skip_reason = "Max retries reached" + for i in range(128): + try: + await self.loop.sock_connect(sock, addr) + except ConnectionRefusedError as e: + skip_reason = e + except OSError as e: + skip_reason = e + + # Retry only for this error: + # [WinError 10022] An invalid argument was supplied + if getattr(e, 'winerror', 0) != 10022: + break + else: + # success + return + + self.skipTest(skip_reason) + + def test_sock_client_racing(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + with sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_recv_racing(httpd, sock), 10)) + sock = socket.socket() + with sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_recv_into_racing(httpd, sock), 10)) + listener = socket.socket() + sock = socket.socket() + with listener, sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_send_racing(listener, sock), 10)) + + def test_sock_client_connect_racing(self): + listener = socket.socket() + sock = socket.socket() + with listener, sock: + self.loop.run_until_complete(asyncio.wait_for( + self._basetest_sock_connect_racing(listener, sock), 10)) + + async def _basetest_huge_content(self, address): + sock = 
socket.socket() + sock.setblocking(False) + DATA_SIZE = 10_000_00 + + chunk = b'0123456789' * (DATA_SIZE // 10) + + await self.loop.sock_connect(sock, address) + await self.loop.sock_sendall(sock, + (b'POST /loop HTTP/1.0\r\n' + + b'Content-Length: %d\r\n' % DATA_SIZE + + b'\r\n')) + + task = asyncio.create_task(self.loop.sock_sendall(sock, chunk)) + + data = await self.loop.sock_recv(sock, DATA_SIZE) + # HTTP headers size is less than MTU, + # they are sent by the first packet always + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + while data.find(b'\r\n\r\n') == -1: + data += await self.loop.sock_recv(sock, DATA_SIZE) + # Strip headers + headers = data[:data.index(b'\r\n\r\n') + 4] + data = data[len(headers):] + + size = DATA_SIZE + checker = cycle(b'0123456789') + + expected = bytes(islice(checker, len(data))) + self.assertEqual(data, expected) + size -= len(data) + + while True: + data = await self.loop.sock_recv(sock, DATA_SIZE) + if not data: + break + expected = bytes(islice(checker, len(data))) + self.assertEqual(data, expected) + size -= len(data) + self.assertEqual(size, 0) + + await task + sock.close() + + def test_huge_content(self): + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete( + self._basetest_huge_content(httpd.address)) + + async def _basetest_huge_content_recvinto(self, address): + sock = socket.socket() + sock.setblocking(False) + DATA_SIZE = 10_000_00 + + chunk = b'0123456789' * (DATA_SIZE // 10) + + await self.loop.sock_connect(sock, address) + await self.loop.sock_sendall(sock, + (b'POST /loop HTTP/1.0\r\n' + + b'Content-Length: %d\r\n' % DATA_SIZE + + b'\r\n')) + + task = asyncio.create_task(self.loop.sock_sendall(sock, chunk)) + + array = bytearray(DATA_SIZE) + buf = memoryview(array) + + nbytes = await self.loop.sock_recv_into(sock, buf) + data = bytes(buf[:nbytes]) + # HTTP headers size is less than MTU, + # they are sent by the first packet always + self.assertTrue(data.startswith(b'HTTP/1.0 200 
OK')) + while data.find(b'\r\n\r\n') == -1: + nbytes = await self.loop.sock_recv_into(sock, buf) + data = bytes(buf[:nbytes]) + # Strip headers + headers = data[:data.index(b'\r\n\r\n') + 4] + data = data[len(headers):] + + size = DATA_SIZE + checker = cycle(b'0123456789') + + expected = bytes(islice(checker, len(data))) + self.assertEqual(data, expected) + size -= len(data) + + while True: + nbytes = await self.loop.sock_recv_into(sock, buf) + data = buf[:nbytes] + if not data: + break + expected = bytes(islice(checker, len(data))) + self.assertEqual(data, expected) + size -= len(data) + self.assertEqual(size, 0) + + await task + sock.close() + + def test_huge_content_recvinto(self): + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete( + self._basetest_huge_content_recvinto(httpd.address)) + + async def _basetest_datagram_recvfrom(self, server_address): + # Happy path, sock.sendto() returns immediately + data = b'\x01' * 4096 + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.setblocking(False) + await self.loop.sock_sendto(sock, data, server_address) + received_data, from_addr = await self.loop.sock_recvfrom( + sock, 4096) + self.assertEqual(received_data, data) + self.assertEqual(from_addr, server_address) + + def test_recvfrom(self): + with test_utils.run_udp_echo_server() as server_address: + self.loop.run_until_complete( + self._basetest_datagram_recvfrom(server_address)) + + async def _basetest_datagram_recvfrom_into(self, server_address): + # Happy path, sock.sendto() returns immediately + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.setblocking(False) + + buf = bytearray(4096) + data = b'\x01' * 4096 + await self.loop.sock_sendto(sock, data, server_address) + num_bytes, from_addr = await self.loop.sock_recvfrom_into( + sock, buf) + self.assertEqual(num_bytes, 4096) + self.assertEqual(buf, data) + self.assertEqual(from_addr, server_address) + + buf = bytearray(8192) + await 
self.loop.sock_sendto(sock, data, server_address) + num_bytes, from_addr = await self.loop.sock_recvfrom_into( + sock, buf, 4096) + self.assertEqual(num_bytes, 4096) + self.assertEqual(buf[:4096], data[:4096]) + self.assertEqual(from_addr, server_address) + + def test_recvfrom_into(self): + with test_utils.run_udp_echo_server() as server_address: + self.loop.run_until_complete( + self._basetest_datagram_recvfrom_into(server_address)) + + async def _basetest_datagram_sendto_blocking(self, server_address): + # Sad path, sock.sendto() raises BlockingIOError + # This involves patching sock.sendto() to raise BlockingIOError but + # sendto() is not used by the proactor event loop + data = b'\x01' * 4096 + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.setblocking(False) + mock_sock = Mock(sock) + mock_sock.gettimeout = sock.gettimeout + mock_sock.sendto.configure_mock(side_effect=BlockingIOError) + mock_sock.fileno = sock.fileno + self.loop.call_soon( + lambda: setattr(mock_sock, 'sendto', sock.sendto) + ) + await self.loop.sock_sendto(mock_sock, data, server_address) + + received_data, from_addr = await self.loop.sock_recvfrom( + sock, 4096) + self.assertEqual(received_data, data) + self.assertEqual(from_addr, server_address) + + def test_sendto_blocking(self): + if sys.platform == 'win32': + if isinstance(self.loop, asyncio.ProactorEventLoop): + raise unittest.SkipTest('Not relevant to ProactorEventLoop') + + with test_utils.run_udp_echo_server() as server_address: + self.loop.run_until_complete( + self._basetest_datagram_sendto_blocking(server_address)) + + @socket_helper.skip_unless_bind_unix_socket + def test_unix_sock_client_ops(self): + with test_utils.run_test_unix_server() as httpd: + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_client_ops(httpd, sock) + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_recv_into(httpd, sock) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address 
= None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + def test_cancel_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + sockaddr = listener.getsockname() + f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(f) + + listener.close() + client = socket.socket() + client.setblocking(False) + f = self.loop.sock_connect(client, sockaddr) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete(f) + + client.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except BaseException: + pass + else: + break + else: + self.fail('Can not create socket.') + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = 
self.loop.run_until_complete(f) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + +if sys.platform == 'win32': + + class SelectEventLoopTests(BaseSockTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop() + + + class ProactorEventLoopTests(BaseSockTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.ProactorEventLoop() + + + async def _basetest_datagram_send_to_non_listening_address(self, + recvfrom): + # see: + # https://github.com/python/cpython/issues/91227 + # https://github.com/python/cpython/issues/88906 + # https://bugs.python.org/issue47071 + # https://bugs.python.org/issue44743 + # The Proactor event loop would fail to receive datagram messages + # after sending a message to an address that wasn't listening. + + def create_socket(): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + sock.bind(('127.0.0.1', 0)) + return sock + + socket_1 = create_socket() + addr_1 = socket_1.getsockname() + + socket_2 = create_socket() + addr_2 = socket_2.getsockname() + + # creating and immediately closing this to try to get an address + # that is not listening + socket_3 = create_socket() + addr_3 = socket_3.getsockname() + socket_3.shutdown(socket.SHUT_RDWR) + socket_3.close() + + socket_1_recv_task = self.loop.create_task(recvfrom(socket_1)) + socket_2_recv_task = self.loop.create_task(recvfrom(socket_2)) + await asyncio.sleep(0) + + await self.loop.sock_sendto(socket_1, b'a', addr_2) + self.assertEqual(await socket_2_recv_task, b'a') + + await self.loop.sock_sendto(socket_2, b'b', addr_1) + self.assertEqual(await socket_1_recv_task, b'b') + socket_1_recv_task = self.loop.create_task(recvfrom(socket_1)) + await asyncio.sleep(0) + + # this should send to an address that isn't listening + await 
self.loop.sock_sendto(socket_1, b'c', addr_3) + self.assertEqual(await socket_1_recv_task, b'') + socket_1_recv_task = self.loop.create_task(recvfrom(socket_1)) + await asyncio.sleep(0) + + # socket 1 should still be able to receive messages after sending + # to an address that wasn't listening + socket_2.sendto(b'd', addr_1) + self.assertEqual(await socket_1_recv_task, b'd') + + socket_1.shutdown(socket.SHUT_RDWR) + socket_1.close() + socket_2.shutdown(socket.SHUT_RDWR) + socket_2.close() + + + def test_datagram_send_to_non_listening_address_recvfrom(self): + async def recvfrom(socket): + data, _ = await self.loop.sock_recvfrom(socket, 4096) + return data + + self.loop.run_until_complete( + self._basetest_datagram_send_to_non_listening_address( + recvfrom)) + + + def test_datagram_send_to_non_listening_address_recvfrom_into(self): + async def recvfrom_into(socket): + buf = bytearray(4096) + length, _ = await self.loop.sock_recvfrom_into(socket, buf, + 4096) + return buf[:length] + + self.loop.run_until_complete( + self._basetest_datagram_send_to_non_listening_address( + recvfrom_into)) + +else: + import selectors + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(BaseSockTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(BaseSockTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(BaseSockTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. 
+ class SelectEventLoopTests(BaseSockTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.SelectSelector()) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py new file mode 100644 index 00000000000..96d03658a07 --- /dev/null +++ b/Lib/test/test_asyncio/test_ssl.py @@ -0,0 +1,1928 @@ +# Contains code from https://github.com/MagicStack/uvloop/tree/v0.16.0 +# SPDX-License-Identifier: PSF-2.0 AND (MIT OR Apache-2.0) +# SPDX-FileCopyrightText: Copyright (c) 2015-2021 MagicStack Inc. http://magic.io + +import asyncio +import contextlib +import gc +import logging +import select +import socket +import sys +import tempfile +import threading +import time +import unittest.mock +import weakref +import unittest + +try: + import ssl +except ImportError: + ssl = None + +from test import support +from test.test_asyncio import utils as test_utils + + +MACOS = (sys.platform == 'darwin') +BUF_MULTIPLIER = 1024 if not MACOS else 64 + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class MyBaseProto(asyncio.Protocol): + connected = None + done = None + + def __init__(self, loop=None): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = asyncio.Future(loop=loop) + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class 
MessageOutFilter(logging.Filter): + def __init__(self, msg): + self.msg = msg + + def filter(self, record): + if self.msg in record.msg: + return False + return True + + +@unittest.skipIf(ssl is None, 'No ssl module') +class TestSSL(test_utils.TestCase): + + PAYLOAD_SIZE = 1024 * 100 + TIMEOUT = support.LONG_TIMEOUT + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + self.addCleanup(self.loop.close) + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.doCleanups() + support.gc_collect() + super().tearDown() + + def tcp_server(self, server_prog, *, + family=socket.AF_INET, + addr=None, + timeout=support.SHORT_TIMEOUT, + backlog=1, + max_clients=10): + + if addr is None: + if family == getattr(socket, "AF_UNIX", None): + with tempfile.NamedTemporaryFile() as tmp: + addr = tmp.name + else: + addr = ('127.0.0.1', 0) + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + try: + sock.bind(addr) + sock.listen(backlog) + except OSError as ex: + sock.close() + raise ex + + return TestThreadedServer( + self, sock, server_prog, timeout, max_clients) + + def tcp_client(self, client_prog, + family=socket.AF_INET, + timeout=support.SHORT_TIMEOUT): + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + return TestThreadedClient( + self, sock, client_prog, timeout) + + def unix_server(self, *args, **kwargs): + return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) + + def unix_client(self, *args, **kwargs): + return self.tcp_client(*args, family=socket.AF_UNIX, 
**kwargs) + + def _create_server_ssl_context(self, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext + + def _create_client_ssl_context(self, *, disable_verify=True): + sslcontext = ssl.create_default_context() + sslcontext.check_hostname = False + if disable_verify: + sslcontext.verify_mode = ssl.CERT_NONE + return sslcontext + + @contextlib.contextmanager + def _silence_eof_received_warning(self): + # TODO This warning has to be fixed in asyncio. + logger = logging.getLogger('asyncio') + filter = MessageOutFilter('has no effect when using ssl') + logger.addFilter(filter) + try: + yield + finally: + logger.removeFilter(filter) + + def _abort_socket_test(self, ex): + try: + self.loop.stop() + finally: + self.fail(ex) + + def new_loop(self): + return asyncio.new_event_loop() + + def new_policy(self): + return asyncio.DefaultEventLoopPolicy() + + async def wait_closed(self, obj): + if not isinstance(obj, asyncio.StreamWriter): + return + try: + await obj.wait_closed() + except (BrokenPipeError, ConnectionError): + pass + + @support.bigmemtest(size=25, memuse=90*2**20, dry_run=False) + def test_create_server_ssl_1(self, size): + CNT = 0 # number of clients that were successful + TOTAL_CNT = size # total number of clients that test will create + TIMEOUT = support.LONG_TIMEOUT # timeout for this test + + A_DATA = b'A' * 1024 * BUF_MULTIPLIER + B_DATA = b'B' * 1024 * BUF_MULTIPLIER + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + + clients = [] + + async def handle_client(reader, writer): + nonlocal CNT + + data = await reader.readexactly(len(A_DATA)) + self.assertEqual(data, A_DATA) + writer.write(b'OK') + + data = await reader.readexactly(len(B_DATA)) + self.assertEqual(data, B_DATA) + writer.writelines([b'SP', bytearray(b'A'), 
memoryview(b'M')]) + + await writer.drain() + writer.close() + + CNT += 1 + + async def test_client(addr): + fut = asyncio.Future() + + def prog(sock): + try: + sock.starttls(client_sslctx) + sock.connect(addr) + sock.send(A_DATA) + + data = sock.recv_all(2) + self.assertEqual(data, b'OK') + + sock.send(B_DATA) + data = sock.recv_all(4) + self.assertEqual(data, b'SPAM') + + sock.close() + + except Exception as ex: + self.loop.call_soon_threadsafe(fut.set_exception, ex) + else: + self.loop.call_soon_threadsafe(fut.set_result, None) + + client = self.tcp_client(prog) + client.start() + clients.append(client) + + await fut + + async def start_server(): + extras = {} + extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) + + srv = await asyncio.start_server( + handle_client, + '127.0.0.1', 0, + family=socket.AF_INET, + ssl=sslctx, + **extras) + + try: + srv_socks = srv.sockets + self.assertTrue(srv_socks) + + addr = srv_socks[0].getsockname() + + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(test_client(addr)) + + await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) + + finally: + self.loop.call_soon(srv.close) + await srv.wait_closed() + + with self._silence_eof_received_warning(): + self.loop.run_until_complete(start_server()) + + self.assertEqual(CNT, TOTAL_CNT) + + for client in clients: + client.stop() + + def test_create_connection_ssl_1(self): + self.loop.set_exception_handler(None) + + CNT = 0 + TOTAL_CNT = 25 + + A_DATA = b'A' * 1024 * BUF_MULTIPLIER + B_DATA = b'B' * 1024 * BUF_MULTIPLIER + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + + def server(sock): + sock.starttls( + sslctx, + server_side=True) + + data = sock.recv_all(len(A_DATA)) + self.assertEqual(data, A_DATA) + sock.send(b'OK') + + data = sock.recv_all(len(B_DATA)) + self.assertEqual(data, B_DATA) + sock.send(b'SPAM') + + sock.close() + + async def client(addr): + extras = {} + 
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + **extras) + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + writer.write(B_DATA) + self.assertEqual(await reader.readexactly(4), b'SPAM') + + nonlocal CNT + CNT += 1 + + writer.close() + await self.wait_closed(writer) + + async def client_sock(addr): + sock = socket.socket() + sock.connect(addr) + reader, writer = await asyncio.open_connection( + sock=sock, + ssl=client_sslctx, + server_hostname='') + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + writer.write(B_DATA) + self.assertEqual(await reader.readexactly(4), b'SPAM') + + nonlocal CNT + CNT += 1 + + writer.close() + await self.wait_closed(writer) + sock.close() + + def run(coro): + nonlocal CNT + CNT = 0 + + async def _gather(*tasks): + # trampoline + return await asyncio.gather(*tasks) + + with self.tcp_server(server, + max_clients=TOTAL_CNT, + backlog=TOTAL_CNT) as srv: + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(coro(srv.addr)) + + self.loop.run_until_complete(_gather(*tasks)) + + self.assertEqual(CNT, TOTAL_CNT) + + with self._silence_eof_received_warning(): + run(client) + + with self._silence_eof_received_warning(): + run(client_sock) + + def test_create_connection_ssl_slow_handshake(self): + client_sslctx = self._create_client_ssl_context() + + # silence error logger + self.loop.set_exception_handler(lambda *args: None) + + def server(sock): + try: + sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + pass + finally: + sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=1.0) + writer.close() + await self.wait_closed(writer) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaisesRegex( + 
ConnectionAbortedError, + r'SSL handshake.*is taking longer'): + + self.loop.run_until_complete(client(srv.addr)) + + def test_create_connection_ssl_failed_certificate(self): + # silence error logger + self.loop.set_exception_handler(lambda *args: None) + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context(disable_verify=False) + + def server(sock): + try: + sock.starttls( + sslctx, + server_side=True) + sock.connect() + except (ssl.SSLError, OSError): + pass + finally: + sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=support.SHORT_TIMEOUT) + writer.close() + await self.wait_closed(writer) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(ssl.SSLCertVerificationError): + self.loop.run_until_complete(client(srv.addr)) + + def test_ssl_handshake_timeout(self): + # bpo-29970: Check that a connection is aborted if handshake is not + # completed in timeout period, instead of remaining open indefinitely + client_sslctx = test_utils.simple_client_sslcontext() + + # silence error logger + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + server_side_aborted = False + + def server(sock): + nonlocal server_side_aborted + try: + sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + server_side_aborted = True + finally: + sock.close() + + async def client(addr): + await asyncio.wait_for( + self.loop.create_connection( + asyncio.Protocol, + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=10.0), + 0.5) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(client(srv.addr)) + + self.assertTrue(server_side_aborted) + + # Python issue #23197: cancelling a 
handshake must not raise an + # exception or log an error, even if the handshake failed + self.assertEqual(messages, []) + + def test_ssl_handshake_connection_lost(self): + # #246: make sure that no connection_lost() is called before + # connection_made() is called first + + client_sslctx = test_utils.simple_client_sslcontext() + + # silence error logger + self.loop.set_exception_handler(lambda loop, ctx: None) + + connection_made_called = False + connection_lost_called = False + + def server(sock): + sock.recv(1024) + # break the connection during handshake + sock.close() + + class ClientProto(asyncio.Protocol): + def connection_made(self, transport): + nonlocal connection_made_called + connection_made_called = True + + def connection_lost(self, exc): + nonlocal connection_lost_called + connection_lost_called = True + + async def client(addr): + await self.loop.create_connection( + ClientProto, + *addr, + ssl=client_sslctx, + server_hostname=''), + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(ConnectionResetError): + self.loop.run_until_complete(client(srv.addr)) + + if connection_lost_called: + if connection_made_called: + self.fail("unexpected call to connection_lost()") + else: + self.fail("unexpected call to connection_lost() without" + "calling connection_made()") + elif connection_made_called: + self.fail("unexpected call to connection_made()") + + def test_ssl_connect_accepted_socket(self): + proto = ssl.PROTOCOL_TLS_SERVER + server_context = ssl.SSLContext(proto) + server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY) + if hasattr(server_context, 'check_hostname'): + server_context.check_hostname = False + server_context.verify_mode = ssl.CERT_NONE + + client_context = ssl.SSLContext(proto) + if hasattr(server_context, 'check_hostname'): + client_context.check_hostname = False + client_context.verify_mode = ssl.CERT_NONE + + def test_connect_accepted_socket(self, server_ssl=None, 
client_ssl=None): + loop = self.loop + + class MyProto(MyBaseProto): + + def connection_lost(self, exc): + super().connection_lost(exc) + loop.call_soon(loop.stop) + + def data_received(self, data): + super().data_received(data) + self.transport.write(expected_response) + + lsock = socket.socket(socket.AF_INET) + lsock.bind(('127.0.0.1', 0)) + lsock.listen(1) + addr = lsock.getsockname() + + message = b'test data' + response = None + expected_response = b'roger' + + def client(): + nonlocal response + try: + csock = socket.socket(socket.AF_INET) + if client_ssl is not None: + csock = client_ssl.wrap_socket(csock) + csock.connect(addr) + csock.sendall(message) + response = csock.recv(99) + csock.close() + except Exception as exc: + print( + "Failure in client thread in test_connect_accepted_socket", + exc) + + thread = threading.Thread(target=client, daemon=True) + thread.start() + + conn, _ = lsock.accept() + proto = MyProto(loop=loop) + proto.loop = loop + + extras = {} + if server_ssl: + extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) + + f = loop.create_task( + loop.connect_accepted_socket( + (lambda: proto), conn, ssl=server_ssl, + **extras)) + loop.run_forever() + conn.close() + lsock.close() + + thread.join(1) + self.assertFalse(thread.is_alive()) + self.assertEqual(proto.state, 'CLOSED') + self.assertEqual(proto.nbytes, len(message)) + self.assertEqual(response, expected_response) + tr, _ = f.result() + + if server_ssl: + self.assertIn('SSL', tr.__class__.__name__) + + tr.close() + # let it close + self.loop.run_until_complete(asyncio.sleep(0.1)) + + def test_start_tls_client_corrupted_ssl(self): + self.loop.set_exception_handler(lambda loop, ctx: None) + + sslctx = test_utils.simple_server_sslcontext() + client_sslctx = test_utils.simple_client_sslcontext() + + def server(sock): + orig_sock = sock.dup() + try: + sock.starttls( + sslctx, + server_side=True) + sock.sendall(b'A\n') + sock.recv_all(1) + orig_sock.send(b'please corrupt the SSL 
connection') + except ssl.SSLError: + pass + finally: + sock.close() + orig_sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + + self.assertEqual(await reader.readline(), b'A\n') + writer.write(b'B') + with self.assertRaises(ssl.SSLError): + await reader.readline() + writer.close() + try: + await self.wait_closed(writer) + except ssl.SSLError: + pass + return 'OK' + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + res = self.loop.run_until_complete(client(srv.addr)) + + self.assertEqual(res, 'OK') + + @unittest.skip('TODO: RUSTPYTHON') + # RuntimeError: Event loop stopped before Future completed. + def test_start_tls_client_reg_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.starttls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.unwrap() + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. 
+ self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data, b'O') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=support.SHORT_TIMEOUT)) + + @unittest.skip('TODO: RUSTPYTHON') + # RuntimeError: Event loop stopped before Future completed. + def test_create_connection_memory_leak(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_context = self._create_client_ssl_context() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + sock.starttls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.unwrap() + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + # XXX: We assume user stores the transport in protocol + proto.tr = tr + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. 
+ self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr, + ssl=client_context) + + self.assertEqual(await on_data, b'O') + tr.write(HELLO_MSG) + await on_eof + + tr.close() + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=support.SHORT_TIMEOUT)) + + # No garbage is left for SSL client from loop.create_connection, even + # if user stores the SSLTransport in corresponding protocol instance + client_context = weakref.ref(client_context) + self.assertIsNone(client_context()) + + @unittest.skip('TODO: RUSTPYTHON') + # RuntimeError: Event loop stopped before Future completed. 
+ def test_start_tls_client_buf_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + client_con_made_calls = 0 + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.starttls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.sendall(b'2') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.unwrap() + sock.close() + + class ClientProtoFirst(asyncio.BufferedProtocol): + def __init__(self, on_data): + self.on_data = on_data + self.buf = bytearray(1) + + def connection_made(self, tr): + nonlocal client_con_made_calls + client_con_made_calls += 1 + + def get_buffer(self, sizehint): + return self.buf + + def buffer_updated(self, nsize): + assert nsize == 1 + self.on_data.set_result(bytes(self.buf[:nsize])) + + def eof_received(self): + pass + + class ClientProtoSecond(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(self, tr): + nonlocal client_con_made_calls + client_con_made_calls += 1 + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data1 = self.loop.create_future() + on_data2 = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProtoFirst(on_data1), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data1, b'O') + new_tr.write(HELLO_MSG) + + new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) + self.assertEqual(await 
on_data2, b'2') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + # connection_made() should be called only once -- when + # we establish connection for the first time. Start TLS + # doesn't call connection_made() on application protocols. + self.assertEqual(client_con_made_calls, 1) + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=self.TIMEOUT)) + + def test_start_tls_slow_client_cancel(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + client_context = test_utils.simple_client_sslcontext() + server_waits_on_handshake = self.loop.create_future() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + try: + self.loop.call_soon_threadsafe( + server_waits_on_handshake.set_result, None) + data = sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + pass + finally: + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. 
+ self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + + await server_waits_on_handshake + + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + self.loop.start_tls(tr, proto, client_context), + 0.5) + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=support.SHORT_TIMEOUT)) + + @unittest.skip('TODO: RUSTPYTHON') + # RuntimeError: Event loop stopped before Future completed. + def test_start_tls_server_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def client(sock, addr): + sock.settimeout(self.TIMEOUT) + + sock.connect(addr) + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.starttls(client_context) + sock.sendall(HELLO_MSG) + + sock.unwrap() + sock.close() + + class ServerProto(asyncio.Protocol): + def __init__(self, on_con, on_eof, on_con_lost): + self.on_con = on_con + self.on_eof = on_eof + self.on_con_lost = on_con_lost + self.data = b'' + + def connection_made(self, tr): + self.on_con.set_result(tr) + + def data_received(self, data): + self.data += data + + def eof_received(self): + self.on_eof.set_result(1) + + def connection_lost(self, exc): + if exc is None: + self.on_con_lost.set_result(None) + else: + self.on_con_lost.set_exception(exc) + + async def main(proto, on_con, on_eof, on_con_lost): + tr = await on_con + tr.write(HELLO_MSG) + + self.assertEqual(proto.data, b'') + + new_tr = await self.loop.start_tls( + tr, proto, 
server_context, + server_side=True, + ssl_handshake_timeout=self.TIMEOUT) + + await on_eof + await on_con_lost + self.assertEqual(proto.data, HELLO_MSG) + new_tr.close() + + async def run_main(): + on_con = self.loop.create_future() + on_eof = self.loop.create_future() + on_con_lost = self.loop.create_future() + proto = ServerProto(on_con, on_eof, on_con_lost) + + server = await self.loop.create_server( + lambda: proto, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + + with self.tcp_client(lambda sock: client(sock, addr), + timeout=self.TIMEOUT): + await asyncio.wait_for( + main(proto, on_con, on_eof, on_con_lost), + timeout=self.TIMEOUT) + + server.close() + await server.wait_closed() + + self.loop.run_until_complete(run_main()) + + @support.bigmemtest(size=25, memuse=90*2**20, dry_run=False) + def test_create_server_ssl_over_ssl(self, size): + CNT = 0 # number of clients that were successful + TOTAL_CNT = size # total number of clients that test will create + TIMEOUT = support.LONG_TIMEOUT # timeout for this test + + A_DATA = b'A' * 1024 * BUF_MULTIPLIER + B_DATA = b'B' * 1024 * BUF_MULTIPLIER + + sslctx_1 = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx_1 = self._create_client_ssl_context() + sslctx_2 = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx_2 = self._create_client_ssl_context() + + clients = [] + + async def handle_client(reader, writer): + nonlocal CNT + + data = await reader.readexactly(len(A_DATA)) + self.assertEqual(data, A_DATA) + writer.write(b'OK') + + data = await reader.readexactly(len(B_DATA)) + self.assertEqual(data, B_DATA) + writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) + + await writer.drain() + writer.close() + + CNT += 1 + + class ServerProtocol(asyncio.StreamReaderProtocol): + def connection_made(self, transport): + super_ = super() + transport.pause_reading() + fut = self._loop.create_task(self._loop.start_tls( + 
transport, self, sslctx_2, server_side=True)) + + def cb(_): + try: + tr = fut.result() + except Exception as ex: + super_.connection_lost(ex) + else: + super_.connection_made(tr) + fut.add_done_callback(cb) + + def server_protocol_factory(): + reader = asyncio.StreamReader() + protocol = ServerProtocol(reader, handle_client) + return protocol + + async def test_client(addr): + fut = asyncio.Future() + + def prog(sock): + try: + sock.connect(addr) + sock.starttls(client_sslctx_1) + + # because wrap_socket() doesn't work correctly on + # SSLSocket, we have to do the 2nd level SSL manually + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) + + def do(func, *args): + while True: + try: + rv = func(*args) + break + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(65536)) + if outgoing.pending: + sock.send(outgoing.read()) + return rv + + do(sslobj.do_handshake) + + do(sslobj.write, A_DATA) + data = do(sslobj.read, 2) + self.assertEqual(data, b'OK') + + do(sslobj.write, B_DATA) + data = b'' + while True: + chunk = do(sslobj.read, 4) + if not chunk: + break + data += chunk + self.assertEqual(data, b'SPAM') + + do(sslobj.unwrap) + sock.close() + + except Exception as ex: + self.loop.call_soon_threadsafe(fut.set_exception, ex) + sock.close() + else: + self.loop.call_soon_threadsafe(fut.set_result, None) + + client = self.tcp_client(prog) + client.start() + clients.append(client) + + await fut + + async def start_server(): + extras = {} + + srv = await self.loop.create_server( + server_protocol_factory, + '127.0.0.1', 0, + family=socket.AF_INET, + ssl=sslctx_1, + **extras) + + try: + srv_socks = srv.sockets + self.assertTrue(srv_socks) + + addr = srv_socks[0].getsockname() + + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(test_client(addr)) + + await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) + + finally: + self.loop.call_soon(srv.close) 
+ await srv.wait_closed() + + with self._silence_eof_received_warning(): + self.loop.run_until_complete(start_server()) + + self.assertEqual(CNT, TOTAL_CNT) + + for client in clients: + client.stop() + + @unittest.skip('TODO: RUSTPYTHON') + # RuntimeError: Event loop stopped before Future completed. + def test_shutdown_cleanly(self): + CNT = 0 + TOTAL_CNT = 25 + + A_DATA = b'A' * 1024 * BUF_MULTIPLIER + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + + def server(sock): + sock.starttls( + sslctx, + server_side=True) + + data = sock.recv_all(len(A_DATA)) + self.assertEqual(data, A_DATA) + sock.send(b'OK') + + sock.unwrap() + + sock.close() + + async def client(addr): + extras = {} + extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + **extras) + + writer.write(A_DATA) + self.assertEqual(await reader.readexactly(2), b'OK') + + self.assertEqual(await reader.read(), b'') + + nonlocal CNT + CNT += 1 + + writer.close() + await self.wait_closed(writer) + + def run(coro): + nonlocal CNT + CNT = 0 + + async def _gather(*tasks): + return await asyncio.gather(*tasks) + + with self.tcp_server(server, + max_clients=TOTAL_CNT, + backlog=TOTAL_CNT) as srv: + tasks = [] + for _ in range(TOTAL_CNT): + tasks.append(coro(srv.addr)) + + self.loop.run_until_complete( + _gather(*tasks)) + + self.assertEqual(CNT, TOTAL_CNT) + + with self._silence_eof_received_warning(): + run(client) + + def test_flush_before_shutdown(self): + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, test_utils.ONLYKEY) + client_sslctx = self._create_client_ssl_context() + + future = None + + def server(sock): + sock.starttls(sslctx, server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + time.sleep(0.5) # hopefully stuck the TCP 
buffer + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + sock.close() + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + async def client(addr): + nonlocal future + future = self.loop.create_future() + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + sslprotocol = writer.transport._ssl_protocol + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + sslprotocol.pause_writing() + for _ in range(SIZE): + writer.write(b'x' * CHUNK) + + writer.close() + sslprotocol.resume_writing() + + await self.wait_closed(writer) + try: + data = await reader.read() + self.assertEqual(data, b'') + except ConnectionResetError: + pass + await future + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + @unittest.skip('TODO: RUSTPYTHON') + # ssl_error.SSLError: cannot read after shutdown + def test_remote_shutdown_receives_trailing_data(self): + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + future = None + + def server(sock): + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) + + while True: + try: + sslobj.do_handshake() + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(16384)) + else: + if outgoing.pending: + sock.send(outgoing.read()) + break + + while True: + try: + data = sslobj.read(4) + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + else: + break + + self.assertEqual(data, b'ping') + sslobj.write(b'pong') + sock.send(outgoing.read()) + + time.sleep(0.2) 
# wait for the peer to fill its backlog + + # send close_notify but don't wait for response + with self.assertRaises(ssl.SSLWantReadError): + sslobj.unwrap() + sock.send(outgoing.read()) + + # should receive all data + data_len = 0 + while True: + try: + chunk = len(sslobj.read(16384)) + data_len += chunk + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + except ssl.SSLZeroReturnError: + break + + self.assertEqual(data_len, CHUNK * SIZE) + + # verify that close_notify is received + sslobj.unwrap() + + sock.close() + + def eof_server(sock): + sock.starttls(sslctx, server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send EOF + sock.shutdown(socket.SHUT_WR) + + # should receive all data + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + + sock.close() + + async def client(addr): + nonlocal future + future = self.loop.create_future() + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + # fill write backlog in a hacky way - renegotiation won't help + for _ in range(SIZE): + writer.transport._test__append_write_backlog(b'x' * CHUNK) + + try: + data = await reader.read() + self.assertEqual(data, b'') + except (BrokenPipeError, ConnectionResetError): + pass + + await future + + writer.close() + await self.wait_closed(writer) + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + with self.tcp_server(run(eof_server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + @unittest.skip('TODO: RUSTPYTHON') + # 
ssl_error.SSLError: cannot read after shutdown + def test_remote_shutdown_receives_trailing_data_on_slow_socket(self): + # This test is the same as test_remote_shutdown_receives_trailing_data, + # except it simulates a socket that is not able to write data in time, + # thus triggering different code path in _SelectorSocketTransport. + # This triggers bug gh-115514, also tested using mocks in + # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close + # The slow path is triggered here by setting SO_SNDBUF, see code and comment below. + + CHUNK = 1024 * 128 + SIZE = 32 + + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + client_sslctx = self._create_client_ssl_context() + future = None + + def server(sock): + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) + + while True: + try: + sslobj.do_handshake() + except ssl.SSLWantReadError: + if outgoing.pending: + sock.send(outgoing.read()) + incoming.write(sock.recv(16384)) + else: + if outgoing.pending: + sock.send(outgoing.read()) + break + + while True: + try: + data = sslobj.read(4) + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + else: + break + + self.assertEqual(data, b'ping') + sslobj.write(b'pong') + sock.send(outgoing.read()) + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send close_notify but don't wait for response + with self.assertRaises(ssl.SSLWantReadError): + sslobj.unwrap() + sock.send(outgoing.read()) + + # should receive all data + data_len = 0 + while True: + try: + chunk = len(sslobj.read(16384)) + data_len += chunk + except ssl.SSLWantReadError: + incoming.write(sock.recv(16384)) + except ssl.SSLZeroReturnError: + break + + self.assertEqual(data_len, CHUNK * SIZE*2) + + # verify that close_notify is received + sslobj.unwrap() + + sock.close() + + def eof_server(sock): + sock.starttls(sslctx, 
server_side=True) + self.assertEqual(sock.recv_all(4), b'ping') + sock.send(b'pong') + + time.sleep(0.2) # wait for the peer to fill its backlog + + # send EOF + sock.shutdown(socket.SHUT_WR) + + # should receive all data + data = sock.recv_all(CHUNK * SIZE) + self.assertEqual(len(data), CHUNK * SIZE) + + sock.close() + + async def client(addr): + nonlocal future + future = self.loop.create_future() + + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + writer.write(b'ping') + data = await reader.readexactly(4) + self.assertEqual(data, b'pong') + + # fill write backlog in a hacky way - renegotiation won't help + for _ in range(SIZE*2): + writer.transport._test__append_write_backlog(b'x' * CHUNK) + + try: + data = await reader.read() + self.assertEqual(data, b'') + except (BrokenPipeError, ConnectionResetError): + pass + + # Make sure _SelectorSocketTransport enters the delayed write + # path in its `write` method by wrapping socket in a fake class + # that acts as if there is not enough space in socket buffer. 
+ # This triggers bug gh-115514, also tested using mocks in + # test.test_asyncio.test_selector_events.SelectorSocketTransportTests.test_write_buffer_after_close + socket_transport = writer.transport._ssl_protocol._transport + + class SocketWrapper: + def __init__(self, sock) -> None: + self.sock = sock + + def __getattr__(self, name): + return getattr(self.sock, name) + + def send(self, data): + # Fake that our write buffer is full, send only half + to_send = len(data)//2 + return self.sock.send(data[:to_send]) + + def _fake_full_write_buffer(data): + if socket_transport._read_ready_cb is None and not isinstance(socket_transport._sock, SocketWrapper): + socket_transport._sock = SocketWrapper(socket_transport._sock) + return unittest.mock.DEFAULT + + with unittest.mock.patch.object( + socket_transport, "write", + wraps=socket_transport.write, + side_effect=_fake_full_write_buffer + ): + await future + + writer.close() + await self.wait_closed(writer) + + def run(meth): + def wrapper(sock): + try: + meth(sock) + except Exception as ex: + self.loop.call_soon_threadsafe(future.set_exception, ex) + else: + self.loop.call_soon_threadsafe(future.set_result, None) + return wrapper + + with self.tcp_server(run(server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + with self.tcp_server(run(eof_server)) as srv: + self.loop.run_until_complete(client(srv.addr)) + + def test_connect_timeout_warning(self): + s = socket.socket(socket.AF_INET) + s.bind(('127.0.0.1', 0)) + addr = s.getsockname() + + async def test(): + try: + await asyncio.wait_for( + self.loop.create_connection(asyncio.Protocol, + *addr, ssl=True), + 0.1) + except (ConnectionRefusedError, asyncio.TimeoutError): + pass + else: + self.fail('TimeoutError is not raised') + + with s: + try: + with self.assertWarns(ResourceWarning) as cm: + self.loop.run_until_complete(test()) + gc.collect() + gc.collect() + gc.collect() + except AssertionError as e: + self.assertEqual(str(e), 'ResourceWarning not 
triggered') + else: + self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) + + # TODO: RUSTPYTHON + # AssertionError: is not None + @unittest.expectedFailure + def test_handshake_timeout_handler_leak(self): + s = socket.socket(socket.AF_INET) + s.bind(('127.0.0.1', 0)) + s.listen(1) + addr = s.getsockname() + + async def test(ctx): + try: + await asyncio.wait_for( + self.loop.create_connection(asyncio.Protocol, *addr, + ssl=ctx), + 0.1) + except (ConnectionRefusedError, asyncio.TimeoutError): + pass + else: + self.fail('TimeoutError is not raised') + + with s: + ctx = ssl.create_default_context() + self.loop.run_until_complete(test(ctx)) + ctx = weakref.ref(ctx) + + # SSLProtocol should be DECREF to 0 + self.assertIsNone(ctx()) + + # TODO: RUSTPYTHON + # AssertionError: is not None + @unittest.expectedFailure + def test_shutdown_timeout_handler_leak(self): + loop = self.loop + + def server(sock): + sslctx = self._create_server_ssl_context( + test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + sock = sslctx.wrap_socket(sock, server_side=True) + sock.recv(32) + sock.close() + + class Protocol(asyncio.Protocol): + def __init__(self): + self.fut = asyncio.Future(loop=loop) + + def connection_lost(self, exc): + self.fut.set_result(None) + + async def client(addr, ctx): + tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) + tr.close() + await pr.fut + + with self.tcp_server(server) as srv: + ctx = self._create_client_ssl_context() + loop.run_until_complete(client(srv.addr, ctx)) + ctx = weakref.ref(ctx) + + # asyncio has no shutdown timeout, but it ends up with a circular + # reference loop - not ideal (introduces gc glitches), but at least + # not leaking + gc.collect() + gc.collect() + gc.collect() + + # SSLProtocol should be DECREF to 0 + self.assertIsNone(ctx()) + + def test_shutdown_timeout_handler_not_set(self): + loop = self.loop + eof = asyncio.Event() + extra = None + + def server(sock): + sslctx = self._create_server_ssl_context( + 
test_utils.ONLYCERT, + test_utils.ONLYKEY + ) + sock = sslctx.wrap_socket(sock, server_side=True) + sock.send(b'hello') + assert sock.recv(1024) == b'world' + sock.send(b'extra bytes') + # sending EOF here + sock.shutdown(socket.SHUT_WR) + loop.call_soon_threadsafe(eof.set) + # make sure we have enough time to reproduce the issue + assert sock.recv(1024) == b'' + sock.close() + + class Protocol(asyncio.Protocol): + def __init__(self): + self.fut = asyncio.Future(loop=loop) + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + if data == b'hello': + self.transport.write(b'world') + # pause reading would make incoming data stay in the sslobj + self.transport.pause_reading() + else: + nonlocal extra + extra = data + + def connection_lost(self, exc): + if exc is None: + self.fut.set_result(None) + else: + self.fut.set_exception(exc) + + async def client(addr): + ctx = self._create_client_ssl_context() + tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) + await eof.wait() + tr.resume_reading() + await pr.fut + tr.close() + assert extra == b'extra bytes' + + with self.tcp_server(server) as srv: + loop.run_until_complete(client(srv.addr)) + + +############################################################################### +# Socket Testing Utilities +############################################################################### + + +class TestSocketWrapper: + + def __init__(self, sock): + self.__sock = sock + + def recv_all(self, n): + buf = b'' + while len(buf) < n: + data = self.recv(n - len(buf)) + if data == b'': + raise ConnectionAbortedError + buf += data + return buf + + def starttls(self, ssl_context, *, + server_side=False, + server_hostname=None, + do_handshake_on_connect=True): + + assert isinstance(ssl_context, ssl.SSLContext) + + ssl_sock = ssl_context.wrap_socket( + self.__sock, server_side=server_side, + server_hostname=server_hostname, + 
do_handshake_on_connect=do_handshake_on_connect) + + if server_side: + ssl_sock.do_handshake() + + self.__sock.close() + self.__sock = ssl_sock + + def __getattr__(self, name): + return getattr(self.__sock, name) + + def __repr__(self): + return '<{} {!r}>'.format(type(self).__name__, self.__sock) + + +class SocketThread(threading.Thread): + + def stop(self): + self._active = False + self.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc): + self.stop() + + +class TestThreadedClient(SocketThread): + + def __init__(self, test, sock, prog, timeout): + threading.Thread.__init__(self, None, None, 'test-client') + self.daemon = True + + self._timeout = timeout + self._sock = sock + self._active = True + self._prog = prog + self._test = test + + def run(self): + try: + self._prog(TestSocketWrapper(self._sock)) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._test._abort_socket_test(ex) + + +class TestThreadedServer(SocketThread): + + def __init__(self, test, sock, prog, timeout, max_clients): + threading.Thread.__init__(self, None, None, 'test-server') + self.daemon = True + + self._clients = 0 + self._finished_clients = 0 + self._max_clients = max_clients + self._timeout = timeout + self._sock = sock + self._active = True + + self._prog = prog + + self._s1, self._s2 = socket.socketpair() + self._s1.setblocking(False) + + self._test = test + + def stop(self): + try: + if self._s2 and self._s2.fileno() != -1: + try: + self._s2.send(b'stop') + except OSError: + pass + finally: + super().stop() + + def run(self): + try: + with self._sock: + self._sock.setblocking(False) + self._run() + finally: + self._s1.close() + self._s2.close() + + def _run(self): + while self._active: + if self._clients >= self._max_clients: + return + + r, w, x = select.select( + [self._sock, self._s1], [], [], self._timeout) + + if self._s1 in r: + return + + if self._sock in r: + try: + conn, addr = self._sock.accept() + 
except BlockingIOError: + continue + except socket.timeout: + if not self._active: + return + else: + raise + else: + self._clients += 1 + conn.settimeout(self._timeout) + try: + with conn: + self._handle_client(conn) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._active = False + try: + raise + finally: + self._test._abort_socket_test(ex) + + def _handle_client(self, sock): + self._prog(TestSocketWrapper(sock)) + + @property + def addr(self): + return self._sock.getsockname() + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py new file mode 100644 index 00000000000..996f1e64cc5 --- /dev/null +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -0,0 +1,855 @@ +"""Tests for asyncio/sslproto.py.""" + +import logging +import socket +import unittest +import weakref +from test import support +from test.support import socket_helper +from unittest import mock +try: + import ssl +except ImportError: + ssl = None + +import asyncio +from asyncio import log +from asyncio import protocols +from asyncio import sslproto +from test.test_asyncio import utils as test_utils +from test.test_asyncio import functional as func_tests + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SslProtoHandshakeTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + def ssl_protocol(self, *, waiter=None, proto=None): + sslcontext = test_utils.dummy_ssl_context() + if proto is None: # app protocol + proto = asyncio.Protocol() + ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter, + ssl_handshake_timeout=0.1) + self.assertIs(ssl_proto._app_transport.get_protocol(), proto) + self.addCleanup(ssl_proto._app_transport.close) + return ssl_proto + + def connection_made(self, 
ssl_proto, *, do_handshake=None): + transport = mock.Mock() + sslobj = mock.Mock() + # emulate reading decompressed data + sslobj.read.side_effect = ssl.SSLWantReadError + sslobj.write.side_effect = ssl.SSLWantReadError + if do_handshake is not None: + sslobj.do_handshake = do_handshake + ssl_proto._sslobj = sslobj + ssl_proto.connection_made(transport) + return transport + + def test_handshake_timeout_zero(self): + sslcontext = test_utils.dummy_ssl_context() + app_proto = mock.Mock() + waiter = mock.Mock() + with self.assertRaisesRegex(ValueError, 'a positive number'): + sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter, + ssl_handshake_timeout=0) + + def test_handshake_timeout_negative(self): + sslcontext = test_utils.dummy_ssl_context() + app_proto = mock.Mock() + waiter = mock.Mock() + with self.assertRaisesRegex(ValueError, 'a positive number'): + sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter, + ssl_handshake_timeout=-10) + + def test_eof_received_waiter(self): + waiter = self.loop.create_future() + ssl_proto = self.ssl_protocol(waiter=waiter) + self.connection_made( + ssl_proto, + do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) + ) + ssl_proto.eof_received() + test_utils.run_briefly(self.loop) + self.assertIsInstance(waiter.exception(), ConnectionResetError) + + def test_fatal_error_no_name_error(self): + # From issue #363. + # _fatal_error() generates a NameError if sslproto.py + # does not import base_events. + waiter = self.loop.create_future() + ssl_proto = self.ssl_protocol(waiter=waiter) + # Temporarily turn off error logging so as not to spoil test output. + log_level = log.logger.getEffectiveLevel() + log.logger.setLevel(logging.FATAL) + try: + ssl_proto._fatal_error(None) + finally: + # Restore error logging. + log.logger.setLevel(log_level) + + def test_connection_lost(self): + # From issue #472. + # yield from waiter hang if lost_connection was called. 
+ waiter = self.loop.create_future() + ssl_proto = self.ssl_protocol(waiter=waiter) + self.connection_made( + ssl_proto, + do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) + ) + ssl_proto.connection_lost(ConnectionAbortedError) + test_utils.run_briefly(self.loop) + self.assertIsInstance(waiter.exception(), ConnectionAbortedError) + + # TODO: RUSTPYTHON + # AssertionError: ConnectionResetError not raised + @unittest.expectedFailure + def test_connection_lost_when_busy(self): + # gh-118950: SSLProtocol.connection_lost not being called when OSError + # is thrown on asyncio.write. + sock = mock.Mock() + sock.fileno = mock.Mock(return_value=12345) + sock.send = mock.Mock(side_effect=BrokenPipeError) + + # construct StreamWriter chain that contains loop dependant logic this emulates + # what _make_ssl_transport() does in BaseSelectorEventLoop + reader = asyncio.StreamReader(limit=2 ** 16, loop=self.loop) + protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop) + ssl_proto = self.ssl_protocol(proto=protocol) + + # emulate reading decompressed data + sslobj = mock.Mock() + sslobj.read.side_effect = ssl.SSLWantReadError + sslobj.write.side_effect = ssl.SSLWantReadError + ssl_proto._sslobj = sslobj + + # emulate outgoing data + data = b'An interesting message' + + outgoing = mock.Mock() + outgoing.read = mock.Mock(return_value=data) + outgoing.pending = len(data) + ssl_proto._outgoing = outgoing + + # use correct socket transport to initialize the SSLProtocol + self.loop._make_socket_transport(sock, ssl_proto) + + transport = ssl_proto._app_transport + writer = asyncio.StreamWriter(transport, protocol, reader, self.loop) + + async def main(): + # writes data to transport + async def write(): + writer.write(data) + await writer.drain() + + # try to write for the first time + await write() + # try to write for the second time, this raises as the connection_lost + # callback should be done with error + with self.assertRaises(ConnectionResetError): + await 
write() + + self.loop.run_until_complete(main()) + + def test_close_during_handshake(self): + # bpo-29743 Closing transport during handshake process leaks socket + waiter = self.loop.create_future() + ssl_proto = self.ssl_protocol(waiter=waiter) + + transport = self.connection_made( + ssl_proto, + do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) + ) + test_utils.run_briefly(self.loop) + + ssl_proto._app_transport.close() + self.assertTrue(transport._force_close.called) + + def test_close_during_ssl_over_ssl(self): + # gh-113214: passing exceptions from the inner wrapped SSL protocol to the + # shim transport provided by the outer SSL protocol should not raise + # attribute errors + outer = self.ssl_protocol(proto=self.ssl_protocol()) + self.connection_made(outer) + # Closing the outer app transport should not raise an exception + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + outer._app_transport.close() + self.assertEqual(messages, []) + + def test_get_extra_info_on_closed_connection(self): + waiter = self.loop.create_future() + ssl_proto = self.ssl_protocol(waiter=waiter) + self.assertIsNone(ssl_proto._get_extra_info('socket')) + default = object() + self.assertIs(ssl_proto._get_extra_info('socket', default), default) + self.connection_made(ssl_proto) + self.assertIsNotNone(ssl_proto._get_extra_info('socket')) + ssl_proto.connection_lost(None) + self.assertIsNone(ssl_proto._get_extra_info('socket')) + + def test_set_new_app_protocol(self): + waiter = self.loop.create_future() + ssl_proto = self.ssl_protocol(waiter=waiter) + new_app_proto = asyncio.Protocol() + ssl_proto._app_transport.set_protocol(new_app_proto) + self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto) + self.assertIs(ssl_proto._app_protocol, new_app_proto) + + def test_data_received_after_closing(self): + ssl_proto = self.ssl_protocol() + self.connection_made(ssl_proto) + transp = ssl_proto._app_transport + + transp.close() + + # 
should not raise + self.assertIsNone(ssl_proto.buffer_updated(5)) + + def test_write_after_closing(self): + ssl_proto = self.ssl_protocol() + self.connection_made(ssl_proto) + transp = ssl_proto._app_transport + transp.close() + + # should not raise + self.assertIsNone(transp.write(b'data')) + + +############################################################################## +# Start TLS Tests +############################################################################## + + +class BaseStartTLS(func_tests.FunctionalTestCaseMixin): + + PAYLOAD_SIZE = 1024 * 100 + TIMEOUT = support.LONG_TIMEOUT + + def new_loop(self): + raise NotImplementedError + + def test_buf_feed_data(self): + + class Proto(asyncio.BufferedProtocol): + + def __init__(self, bufsize, usemv): + self.buf = bytearray(bufsize) + self.mv = memoryview(self.buf) + self.data = b'' + self.usemv = usemv + + def get_buffer(self, sizehint): + if self.usemv: + return self.mv + else: + return self.buf + + def buffer_updated(self, nsize): + if self.usemv: + self.data += self.mv[:nsize] + else: + self.data += self.buf[:nsize] + + for usemv in [False, True]: + proto = Proto(1, usemv) + protocols._feed_data_to_buffered_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(2, usemv) + protocols._feed_data_to_buffered_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(2, usemv) + protocols._feed_data_to_buffered_proto(proto, b'1234') + self.assertEqual(proto.data, b'1234') + + proto = Proto(4, usemv) + protocols._feed_data_to_buffered_proto(proto, b'1234') + self.assertEqual(proto.data, b'1234') + + proto = Proto(100, usemv) + protocols._feed_data_to_buffered_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(0, usemv) + with self.assertRaisesRegex(RuntimeError, 'empty buffer'): + protocols._feed_data_to_buffered_proto(proto, b'12345') + + # TODO: RUSTPYTHON + # AssertionError: is not None + @unittest.expectedFailure + def 
test_start_tls_client_reg_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data, b'O') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=support.SHORT_TIMEOUT)) + + # No garbage is left if SSL is closed uncleanly + client_context = weakref.ref(client_context) + support.gc_collect() + self.assertIsNone(client_context()) + + # TODO: RUSTPYTHON + # AssertionError: is not None + @unittest.expectedFailure + def test_create_connection_memory_leak(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = 
test_utils.simple_client_sslcontext() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + # XXX: We assume user stores the transport in protocol + proto.tr = tr + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr, + ssl=client_context) + + self.assertEqual(await on_data, b'O') + tr.write(HELLO_MSG) + await on_eof + + tr.close() + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=support.SHORT_TIMEOUT)) + + # No garbage is left for SSL client from loop.create_connection, even + # if user stores the SSLTransport in corresponding protocol instance + client_context = weakref.ref(client_context) + support.gc_collect() + self.assertIsNone(client_context()) + + @socket_helper.skip_if_tcp_blackhole + def test_start_tls_client_buf_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + client_con_made_calls = 0 + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + 
sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.sendall(b'2') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + class ClientProtoFirst(asyncio.BufferedProtocol): + def __init__(self, on_data): + self.on_data = on_data + self.buf = bytearray(1) + + def connection_made(self, tr): + nonlocal client_con_made_calls + client_con_made_calls += 1 + + def get_buffer(self, sizehint): + return self.buf + + def buffer_updated(slf, nsize): + self.assertEqual(nsize, 1) + slf.on_data.set_result(bytes(slf.buf[:nsize])) + + class ClientProtoSecond(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(self, tr): + nonlocal client_con_made_calls + client_con_made_calls += 1 + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data1 = self.loop.create_future() + on_data2 = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProtoFirst(on_data1), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data1, b'O') + new_tr.write(HELLO_MSG) + + new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) + self.assertEqual(await on_data2, b'2') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + # connection_made() should be called only once -- when + # we establish connection for the first time. Start TLS + # doesn't call connection_made() on application protocols. 
+ self.assertEqual(client_con_made_calls, 1) + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=self.TIMEOUT)) + + def test_start_tls_slow_client_cancel(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + client_context = test_utils.simple_client_sslcontext() + server_waits_on_handshake = self.loop.create_future() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + try: + self.loop.call_soon_threadsafe( + server_waits_on_handshake.set_result, None) + data = sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + pass + finally: + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + + await server_waits_on_handshake + + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + self.loop.start_tls(tr, proto, client_context), + 0.5) + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + timeout=support.SHORT_TIMEOUT)) + + @socket_helper.skip_if_tcp_blackhole + def test_start_tls_server_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + ANSWER = b'answer' + + server_context = test_utils.simple_server_sslcontext() + client_context = 
test_utils.simple_client_sslcontext() + answer = None + + def client(sock, addr): + nonlocal answer + sock.settimeout(self.TIMEOUT) + + sock.connect(addr) + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(client_context) + sock.sendall(HELLO_MSG) + answer = sock.recv_all(len(ANSWER)) + sock.close() + + class ServerProto(asyncio.Protocol): + def __init__(self, on_con, on_con_lost, on_got_hello): + self.on_con = on_con + self.on_con_lost = on_con_lost + self.on_got_hello = on_got_hello + self.data = b'' + self.transport = None + + def connection_made(self, tr): + self.transport = tr + self.on_con.set_result(tr) + + def replace_transport(self, tr): + self.transport = tr + + def data_received(self, data): + self.data += data + if len(self.data) >= len(HELLO_MSG): + self.on_got_hello.set_result(None) + + def connection_lost(self, exc): + self.transport = None + if exc is None: + self.on_con_lost.set_result(None) + else: + self.on_con_lost.set_exception(exc) + + async def main(proto, on_con, on_con_lost, on_got_hello): + tr = await on_con + tr.write(HELLO_MSG) + + self.assertEqual(proto.data, b'') + + new_tr = await self.loop.start_tls( + tr, proto, server_context, + server_side=True, + ssl_handshake_timeout=self.TIMEOUT) + proto.replace_transport(new_tr) + + await on_got_hello + new_tr.write(ANSWER) + + await on_con_lost + self.assertEqual(proto.data, HELLO_MSG) + new_tr.close() + + async def run_main(): + on_con = self.loop.create_future() + on_con_lost = self.loop.create_future() + on_got_hello = self.loop.create_future() + proto = ServerProto(on_con, on_con_lost, on_got_hello) + + server = await self.loop.create_server( + lambda: proto, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() + + with self.tcp_client(lambda sock: client(sock, addr), + timeout=self.TIMEOUT): + await asyncio.wait_for( + main(proto, on_con, on_con_lost, on_got_hello), + timeout=self.TIMEOUT) + + server.close() + await 
server.wait_closed() + self.assertEqual(answer, ANSWER) + + self.loop.run_until_complete(run_main()) + + def test_start_tls_wrong_args(self): + async def main(): + with self.assertRaisesRegex(TypeError, 'SSLContext, got'): + await self.loop.start_tls(None, None, None) + + sslctx = test_utils.simple_server_sslcontext() + with self.assertRaisesRegex(TypeError, 'is not supported'): + await self.loop.start_tls(None, None, sslctx) + + self.loop.run_until_complete(main()) + + # TODO: RUSTPYTHON + # AssertionError: is not None + @unittest.expectedFailure + def test_handshake_timeout(self): + # bpo-29970: Check that a connection is aborted if handshake is not + # completed in timeout period, instead of remaining open indefinitely + client_sslctx = test_utils.simple_client_sslcontext() + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + server_side_aborted = False + + def server(sock): + nonlocal server_side_aborted + try: + sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + server_side_aborted = True + finally: + sock.close() + + async def client(addr): + await asyncio.wait_for( + self.loop.create_connection( + asyncio.Protocol, + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=support.SHORT_TIMEOUT), + 0.5) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(client(srv.addr)) + + self.assertTrue(server_side_aborted) + + # Python issue #23197: cancelling a handshake must not raise an + # exception or log an error, even if the handshake failed + self.assertEqual(messages, []) + + # The 10s handshake timeout should be cancelled to free related + # objects without really waiting for 10s + client_sslctx = weakref.ref(client_sslctx) + support.gc_collect() + self.assertIsNone(client_sslctx()) + + def test_create_connection_ssl_slow_handshake(self): + client_sslctx = 
test_utils.simple_client_sslcontext() + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + def server(sock): + try: + sock.recv_all(1024 * 1024) + except ConnectionAbortedError: + pass + finally: + sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=1.0) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaisesRegex( + ConnectionAbortedError, + r'SSL handshake.*is taking longer'): + + self.loop.run_until_complete(client(srv.addr)) + + self.assertEqual(messages, []) + + def test_create_connection_ssl_failed_certificate(self): + self.loop.set_exception_handler(lambda loop, ctx: None) + + sslctx = test_utils.simple_server_sslcontext() + client_sslctx = test_utils.simple_client_sslcontext( + disable_verify=False) + + def server(sock): + try: + sock.start_tls( + sslctx, + server_side=True) + except ssl.SSLError: + pass + except OSError: + pass + finally: + sock.close() + + async def client(addr): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + ssl_handshake_timeout=support.LOOPBACK_TIMEOUT) + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + with self.assertRaises(ssl.SSLCertVerificationError): + self.loop.run_until_complete(client(srv.addr)) + + def test_start_tls_client_corrupted_ssl(self): + self.loop.set_exception_handler(lambda loop, ctx: None) + + sslctx = test_utils.simple_server_sslcontext() + client_sslctx = test_utils.simple_client_sslcontext() + + def server(sock): + orig_sock = sock.dup() + try: + sock.start_tls( + sslctx, + server_side=True) + sock.sendall(b'A\n') + sock.recv_all(1) + orig_sock.send(b'please corrupt the SSL connection') + except ssl.SSLError: + pass + finally: + orig_sock.close() + sock.close() + + async def client(addr): + reader, writer = await 
asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='') + + self.assertEqual(await reader.readline(), b'A\n') + writer.write(b'B') + with self.assertRaises(ssl.SSLError): + await reader.readline() + + writer.close() + return 'OK' + + with self.tcp_server(server, + max_clients=1, + backlog=1) as srv: + + res = self.loop.run_until_complete(client(srv.addr)) + + self.assertEqual(res, 'OK') + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase): + + def new_loop(self): + return asyncio.SelectorEventLoop() + + +@unittest.skipIf(ssl is None, 'No ssl module') +@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') +class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase): + + def new_loop(self): + return asyncio.ProactorEventLoop() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py new file mode 100644 index 00000000000..84857b1d58b --- /dev/null +++ b/Lib/test/test_asyncio/test_staggered.py @@ -0,0 +1,156 @@ +import asyncio +import unittest +from asyncio.staggered import staggered_race + +from test import support + +support.requires_working_socket(module=True) + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class StaggeredTests(unittest.IsolatedAsyncioTestCase): + async def test_empty(self): + winner, index, excs = await staggered_race( + [], + delay=None, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(excs, []) + + async def test_one_successful(self): + async def coro(index): + return f'Res: {index}' + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=None, + ) + + self.assertEqual(winner, 'Res: 0') + self.assertEqual(index, 0) + self.assertEqual(excs, [None]) + + async def test_first_error_second_successful(self): + async def coro(index): + if index == 0: + raise 
ValueError(index) + return f'Res: {index}' + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=None, + ) + + self.assertEqual(winner, 'Res: 1') + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIs(excs[1], None) + + async def test_first_timeout_second_successful(self): + async def coro(index): + if index == 0: + await asyncio.sleep(10) # much bigger than delay + return f'Res: {index}' + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=0.1, + ) + + self.assertEqual(winner, 'Res: 1') + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIs(excs[1], None) + + async def test_none_successful(self): + async def coro(index): + raise ValueError(index) + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=None, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIsInstance(excs[1], ValueError) + + + async def test_multiple_winners(self): + event = asyncio.Event() + + async def coro(index): + await event.wait() + return index + + async def do_set(): + event.set() + await asyncio.Event().wait() + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + do_set, + ], + delay=0.1, + ) + self.assertIs(winner, 0) + self.assertIs(index, 0) + self.assertEqual(len(excs), 3) + self.assertIsNone(excs[0], None) + self.assertIsInstance(excs[1], asyncio.CancelledError) + self.assertIsInstance(excs[2], asyncio.CancelledError) + + # TODO: RUSTPYTHON + # AssertionError: Lists differ: ['cancelled 3'] != ['cancelled 1', 'cancelled 2', 'cancelled 3'] + @unittest.expectedFailure + async def test_cancelled(self): + log = [] + with self.assertRaises(TimeoutError): + 
async with asyncio.timeout(None) as cs_outer, asyncio.timeout(None) as cs_inner: + async def coro_fn(): + cs_inner.reschedule(-1) + await asyncio.sleep(0) + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + log.append("cancelled 1") + + cs_outer.reschedule(-1) + await asyncio.sleep(0) + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + log.append("cancelled 2") + try: + await staggered_race([coro_fn], delay=None) + except asyncio.CancelledError: + log.append("cancelled 3") + raise + + self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"]) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py new file mode 100644 index 00000000000..6a90f9b2db2 --- /dev/null +++ b/Lib/test/test_asyncio/test_streams.py @@ -0,0 +1,1282 @@ +"""Tests for streams.py.""" + +import gc +import os +import queue +import pickle +import socket +import sys +import threading +import unittest +from unittest import mock +import warnings +try: + import ssl +except ImportError: + ssl = None + +import asyncio +from test.test_asyncio import utils as test_utils +from test.support import requires_subprocess, socket_helper + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class StreamTests(test_utils.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + # set_event_loop() takes care of closing self.loop in a safe way + super().tearDown() + + def _basetest_open_connection(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = 
reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + writer.close() + self.assertEqual(messages, []) + + def test_open_connection(self): + with test_utils.run_test_server() as httpd: + conn_fut = asyncio.open_connection(*httpd.address) + self._basetest_open_connection(conn_fut) + + @socket_helper.skip_unless_bind_unix_socket + def test_open_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address) + self._basetest_open_connection(conn_fut) + + def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + try: + reader, writer = self.loop.run_until_complete(open_connection_fut) + finally: + asyncio.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + self.assertEqual(messages, []) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_open_connection_no_loop_ssl(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + conn_fut = asyncio.open_connection( + *httpd.address, + ssl=test_utils.dummy_ssl_context()) + + self._basetest_open_connection_no_loop_ssl(conn_fut) + + @socket_helper.skip_unless_bind_unix_socket + @unittest.skipIf(ssl is None, 'No ssl module') + def test_open_unix_connection_no_loop_ssl(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = asyncio.open_unix_connection( + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='', + ) + + self._basetest_open_connection_no_loop_ssl(conn_fut) + + def _basetest_open_connection_error(self, open_connection_fut): + messages = [] + 
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + writer.close() + test_utils.run_briefly(self.loop) + self.assertEqual(messages, []) + + def test_open_connection_error(self): + with test_utils.run_test_server() as httpd: + conn_fut = asyncio.open_connection(*httpd.address) + self._basetest_open_connection_error(conn_fut) + + @socket_helper.skip_unless_bind_unix_socket + def test_open_unix_connection_error(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address) + self._basetest_open_connection_error(conn_fut) + + def test_feed_empty_data(self): + stream = asyncio.StreamReader(loop=self.loop) + + stream.feed_data(b'') + self.assertEqual(b'', stream._buffer) + + def test_feed_nonempty_data(self): + stream = asyncio.StreamReader(loop=self.loop) + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, stream._buffer) + + def test_read_zero(self): + # Read zero bytes. + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(self.DATA, stream._buffer) + + def test_read(self): + # Read bytes. + stream = asyncio.StreamReader(loop=self.loop) + read_task = self.loop.create_task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertEqual(b'', stream._buffer) + + def test_read_line_breaks(self): + # Read bytes without line breaks. 
+ stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(b'line2', stream._buffer) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = asyncio.StreamReader(loop=self.loop) + read_task = self.loop.create_task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(b'', stream._buffer) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = asyncio.StreamReader(loop=self.loop) + read_task = self.loop.create_task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertEqual(b'', stream._buffer) + + def test_read_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.read(2)) + + def test_invalid_limit(self): + with self.assertRaisesRegex(ValueError, 'imit'): + asyncio.StreamReader(limit=0, loop=self.loop) + + with self.assertRaisesRegex(ValueError, 'imit'): + asyncio.StreamReader(limit=-1, loop=self.loop) + + def test_read_limit(self): + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'chunk') + data = self.loop.run_until_complete(stream.read(5)) + self.assertEqual(b'chunk', data) + self.assertEqual(b'', stream._buffer) + + def test_readline(self): + # Read one line. 
'readline' will need to wait for the data + # to come from 'cb' + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'chunk1 ') + read_task = self.loop.create_task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.loop.call_soon(cb) + + line = self.loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(b' chunk4', stream._buffer) + + def test_readline_limit_with_existing_data(self): + # Read one line. The data is in StreamReader's buffer + # before the event loop is run. + + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + # The buffer should contain the remaining data after exception + self.assertEqual(b'line2\n', stream._buffer) + + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + # No b'\n' at the end. The 'limit' is set to 3. So before + # waiting for the new data in buffer, 'readline' will consume + # the entire buffer, and since the length of the consumed data + # is more than 3, it will raise a ValueError. The buffer is + # expected to be empty now. + self.assertEqual(b'', stream._buffer) + + def test_at_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + self.assertFalse(stream.at_eof()) + + stream.feed_data(b'some data\n') + self.assertFalse(stream.at_eof()) + + self.loop.run_until_complete(stream.readline()) + self.assertFalse(stream.at_eof()) + + stream.feed_data(b'some data\n') + stream.feed_eof() + self.loop.run_until_complete(stream.readline()) + self.assertTrue(stream.at_eof()) + + def test_readline_limit(self): + # Read one line. 
StreamReaders are fed with data after + # their 'readline' methods are called. + + stream = asyncio.StreamReader(limit=7, loop=self.loop) + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + # The buffer had just one line of data, and after raising + # a ValueError it should be empty. + self.assertEqual(b'', stream._buffer) + + stream = asyncio.StreamReader(limit=7, loop=self.loop) + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2\n') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual(b'chunk3\n', stream._buffer) + + # check strictness of the limit + stream = asyncio.StreamReader(limit=7, loop=self.loop) + stream.feed_data(b'1234567\n') + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'1234567\n', line) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'12345678\n') + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'12345678') + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', stream._buffer) + + def test_readline_nolimit_nowait(self): + # All needed data for the first 'readline' call will be + # in the buffer. 
+ stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(b'line2\nline3\n', stream._buffer) + + def test_readline_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + self.loop.run_until_complete(stream.readline()) + + data = self.loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual(b'ine3\n', stream._buffer) + + def test_readline_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual(b'', stream._buffer) + + @unittest.skip('TODO: RUSTPYTHON') + # Causing a hang + def test_readuntil_separator(self): + stream = asyncio.StreamReader(loop=self.loop) + with self.assertRaisesRegex(ValueError, 'Separator should be'): + self.loop.run_until_complete(stream.readuntil(separator=b'')) + with self.assertRaisesRegex(ValueError, 'Separator should be'): + self.loop.run_until_complete(stream.readuntil(separator=(b'',))) + with self.assertRaisesRegex(ValueError, 'Separator should contain'): + self.loop.run_until_complete(stream.readuntil(separator=())) + + def test_readuntil_multi_chunks(self): + stream = 
asyncio.StreamReader(loop=self.loop) + + stream.feed_data(b'lineAAA') + data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA')) + self.assertEqual(b'lineAAA', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'lineAAA') + data = self.loop.run_until_complete(stream.readuntil(b'AAA')) + self.assertEqual(b'lineAAA', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'lineAAAxxx') + data = self.loop.run_until_complete(stream.readuntil(b'AAA')) + self.assertEqual(b'lineAAA', data) + self.assertEqual(b'xxx', stream._buffer) + + def test_readuntil_multi_chunks_1(self): + stream = asyncio.StreamReader(loop=self.loop) + + stream.feed_data(b'QWEaa') + stream.feed_data(b'XYaa') + stream.feed_data(b'a') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'QWEaaXYaaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'QWEaa') + stream.feed_data(b'XYa') + stream.feed_data(b'aa') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'QWEaaXYaaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'aaa') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'aaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'Xaaa') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'Xaaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'XXX') + stream.feed_data(b'a') + stream.feed_data(b'a') + stream.feed_data(b'a') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'XXXaaa', data) + self.assertEqual(b'', stream._buffer) + + def test_readuntil_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + data = b'some dataAA' + stream.feed_data(data) + stream.feed_eof() + + with self.assertRaisesRegex(asyncio.IncompleteReadError, + 'undefined expected bytes') as cm: + 
self.loop.run_until_complete(stream.readuntil(b'AAA')) + self.assertEqual(cm.exception.partial, data) + self.assertIsNone(cm.exception.expected) + self.assertEqual(b'', stream._buffer) + + def test_readuntil_limit_found_sep(self): + stream = asyncio.StreamReader(loop=self.loop, limit=3) + stream.feed_data(b'some dataAA') + with self.assertRaisesRegex(asyncio.LimitOverrunError, + 'not found') as cm: + self.loop.run_until_complete(stream.readuntil(b'AAA')) + + self.assertEqual(b'some dataAA', stream._buffer) + + stream.feed_data(b'A') + with self.assertRaisesRegex(asyncio.LimitOverrunError, + 'is found') as cm: + self.loop.run_until_complete(stream.readuntil(b'AAA')) + + self.assertEqual(b'some dataAAA', stream._buffer) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: unexpected type tuple + def test_readuntil_multi_separator(self): + stream = asyncio.StreamReader(loop=self.loop) + + # Simple case + stream.feed_data(b'line 1\nline 2\r') + data = self.loop.run_until_complete(stream.readuntil((b'\r', b'\n'))) + self.assertEqual(b'line 1\n', data) + data = self.loop.run_until_complete(stream.readuntil((b'\r', b'\n'))) + self.assertEqual(b'line 2\r', data) + self.assertEqual(b'', stream._buffer) + + # First end position matches, even if that's a longer match + stream.feed_data(b'ABCDEFG') + data = self.loop.run_until_complete(stream.readuntil((b'DEF', b'BCDE'))) + self.assertEqual(b'ABCDE', data) + self.assertEqual(b'FG', stream._buffer) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: unexpected type tuple + def test_readuntil_multi_separator_limit(self): + stream = asyncio.StreamReader(loop=self.loop, limit=3) + stream.feed_data(b'some dataA') + + with self.assertRaisesRegex(asyncio.LimitOverrunError, + 'is found') as cm: + self.loop.run_until_complete(stream.readuntil((b'A', b'ome dataA'))) + + self.assertEqual(b'some dataA', stream._buffer) + + @unittest.skip('TODO: RUSTPYTHON') + # TypeError: unexpected type tuple + def 
test_readuntil_multi_separator_negative_offset(self): + # If the buffer is big enough for the smallest separator (but does + # not contain it) but too small for the largest, `offset` must not + # become negative. + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'data') + + readuntil_task = self.loop.create_task(stream.readuntil((b'A', b'long sep'))) + self.loop.call_soon(stream.feed_data, b'Z') + self.loop.call_soon(stream.feed_data, b'Aaaa') + + data = self.loop.run_until_complete(readuntil_task) + self.assertEqual(b'dataZA', data) + self.assertEqual(b'aaa', stream._buffer) + + def test_readuntil_bytearray(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'some data\r\n') + data = self.loop.run_until_complete(stream.readuntil(bytearray(b'\r\n'))) + self.assertEqual(b'some data\r\n', data) + self.assertEqual(b'', stream._buffer) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(self.DATA, stream._buffer) + + with self.assertRaisesRegex(ValueError, 'less than zero'): + self.loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(self.DATA, stream._buffer) + + def test_readexactly(self): + # Read exact number of bytes. 
+ stream = asyncio.StreamReader(loop=self.loop) + + n = 2 * len(self.DATA) + read_task = self.loop.create_task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(self.DATA, stream._buffer) + + def test_readexactly_limit(self): + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'chunk') + data = self.loop.run_until_complete(stream.readexactly(5)) + self.assertEqual(b'chunk', data) + self.assertEqual(b'', stream._buffer) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = asyncio.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = self.loop.create_task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.loop.call_soon(cb) + + with self.assertRaises(asyncio.IncompleteReadError) as cm: + self.loop.run_until_complete(read_task) + self.assertEqual(cm.exception.partial, self.DATA) + self.assertEqual(cm.exception.expected, n) + self.assertEqual(str(cm.exception), + '18 bytes read on a total of 36 expected bytes') + self.assertEqual(b'', stream._buffer) + + def test_readexactly_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readexactly(2)) + + def test_exception(self): + stream = asyncio.StreamReader(loop=self.loop) + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = asyncio.StreamReader(loop=self.loop) + + async def set_err(): + stream.set_exception(ValueError()) + + 
t1 = self.loop.create_task(stream.readline()) + t2 = self.loop.create_task(set_err()) + + self.loop.run_until_complete(asyncio.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + def test_exception_cancel(self): + stream = asyncio.StreamReader(loop=self.loop) + + t = self.loop.create_task(stream.readline()) + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + # The following line fails if set_exception() isn't careful. + stream.set_exception(RuntimeError('message')) + test_utils.run_briefly(self.loop) + self.assertIs(stream._waiter, None) + + def test_start_server(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + async def handle_client(self, client_reader, client_writer): + data = await client_reader.readline() + client_writer.write(data) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() + + def start(self): + sock = socket.create_server(('127.0.0.1', 0)) + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client, + sock=sock)) + return sock.getsockname() + + def handle_client_callback(self, client_reader, client_writer): + self.loop.create_task(self.handle_client(client_reader, + client_writer)) + + def start_callback(self): + sock = socket.create_server(('127.0.0.1', 0)) + addr = sock.getsockname() + sock.close() + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client_callback, + host=addr[0], port=addr[1])) + return addr + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + async def client(addr): + reader, writer = await asyncio.open_connection(*addr) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = await reader.readline() + writer.close() + await writer.wait_closed() + return msgback + + messages = [] + 
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + # test the server variant with a coroutine as client handler + server = MyServer(self.loop) + addr = server.start() + msg = self.loop.run_until_complete(self.loop.create_task(client(addr))) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + server = MyServer(self.loop) + addr = server.start_callback() + msg = self.loop.run_until_complete(self.loop.create_task(client(addr))) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + self.assertEqual(messages, []) + + @socket_helper.skip_unless_bind_unix_socket + def test_start_unix_server(self): + + class MyServer: + + def __init__(self, loop, path): + self.server = None + self.loop = loop + self.path = path + + async def handle_client(self, client_reader, client_writer): + data = await client_reader.readline() + client_writer.write(data) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() + + def start(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client, + path=self.path)) + + def handle_client_callback(self, client_reader, client_writer): + self.loop.create_task(self.handle_client(client_reader, + client_writer)) + + def start_callback(self): + start = asyncio.start_unix_server(self.handle_client_callback, + path=self.path) + self.server = self.loop.run_until_complete(start) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + async def client(path): + reader, writer = await asyncio.open_unix_connection(path) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = await reader.readline() + writer.close() + await writer.wait_closed() + return msgback + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + # test the 
server variant with a coroutine as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start() + msg = self.loop.run_until_complete( + self.loop.create_task(client(path))) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start_callback() + msg = self.loop.run_until_complete( + self.loop.create_task(client(path))) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + self.assertEqual(messages, []) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_tls(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + async def handle_client(self, client_reader, client_writer): + data1 = await client_reader.readline() + client_writer.write(data1) + await client_writer.drain() + assert client_writer.get_extra_info('sslcontext') is None + await client_writer.start_tls( + test_utils.simple_server_sslcontext()) + assert client_writer.get_extra_info('sslcontext') is not None + data2 = await client_reader.readline() + client_writer.write(data2) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() + + def start(self): + sock = socket.create_server(('127.0.0.1', 0)) + self.server = self.loop.run_until_complete( + asyncio.start_server(self.handle_client, + sock=sock)) + return sock.getsockname() + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + async def client(addr): + reader, writer = await asyncio.open_connection(*addr) + writer.write(b"hello world 1!\n") + await writer.drain() + msgback1 = await reader.readline() + assert writer.get_extra_info('sslcontext') is None + await writer.start_tls(test_utils.simple_client_sslcontext()) + assert 
writer.get_extra_info('sslcontext') is not None + writer.write(b"hello world 2!\n") + await writer.drain() + msgback2 = await reader.readline() + writer.close() + await writer.wait_closed() + return msgback1, msgback2 + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + server = MyServer(self.loop) + addr = server.start() + msg1, msg2 = self.loop.run_until_complete(client(addr)) + server.stop() + + self.assertEqual(messages, []) + self.assertEqual(msg1, b"hello world 1!\n") + self.assertEqual(msg2, b"hello world 2!\n") + + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") + @requires_subprocess() + def test_read_all_from_pipe_reader(self): + # See asyncio issue 168. This test is derived from the example + # subprocess_attach_read_pipe.py, but we configure the + # StreamReader's limit so that twice it is less than the size + # of the data writer. Also we must explicitly attach a child + # watcher to the event loop. + + code = """\ +import os, sys +fd = int(sys.argv[1]) +os.write(fd, b'data') +os.close(fd) +""" + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(wfd)] + + pipe = open(rfd, 'rb', 0) + reader = asyncio.StreamReader(loop=self.loop, limit=1) + protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop) + transport, _ = self.loop.run_until_complete( + self.loop.connect_read_pipe(lambda: protocol, pipe)) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + try: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + asyncio.set_child_watcher(watcher) + create = asyncio.create_subprocess_exec( + *args, + pass_fds={wfd}, + ) + proc = self.loop.run_until_complete(create) + self.loop.run_until_complete(proc.wait()) + finally: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + asyncio.set_child_watcher(None) + + 
os.close(wfd) + data = self.loop.run_until_complete(reader.read(-1)) + self.assertEqual(data, b'data') + + def test_streamreader_constructor_without_loop(self): + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + asyncio.StreamReader() + + def test_streamreader_constructor_use_running_loop(self): + # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # retrieves the current loop if the loop parameter is not set + async def test(): + return asyncio.StreamReader() + + reader = self.loop.run_until_complete(test()) + self.assertIs(reader._loop, self.loop) + + def test_streamreader_constructor_use_global_loop(self): + # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # retrieves the current loop if the loop parameter is not set + # Deprecated in 3.10, undeprecated in 3.12 + self.addCleanup(asyncio.set_event_loop, None) + asyncio.set_event_loop(self.loop) + reader = asyncio.StreamReader() + self.assertIs(reader._loop, self.loop) + + + def test_streamreaderprotocol_constructor_without_loop(self): + reader = mock.Mock() + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + asyncio.StreamReaderProtocol(reader) + + def test_streamreaderprotocol_constructor_use_running_loop(self): + # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # retrieves the current loop if the loop parameter is not set + reader = mock.Mock() + async def test(): + return asyncio.StreamReaderProtocol(reader) + protocol = self.loop.run_until_complete(test()) + self.assertIs(protocol._loop, self.loop) + + def test_streamreaderprotocol_constructor_use_global_loop(self): + # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # retrieves the current loop if the loop parameter is not set + # Deprecated in 3.10, undeprecated in 3.12 + self.addCleanup(asyncio.set_event_loop, None) + asyncio.set_event_loop(self.loop) + reader = mock.Mock() + protocol = asyncio.StreamReaderProtocol(reader) + 
self.assertIs(protocol._loop, self.loop) + + def test_multiple_drain(self): + # See https://github.com/python/cpython/issues/74116 + drained = 0 + + async def drainer(stream): + nonlocal drained + await stream._drain_helper() + drained += 1 + + async def main(): + loop = asyncio.get_running_loop() + stream = asyncio.streams.FlowControlMixin(loop) + stream.pause_writing() + loop.call_later(0.1, stream.resume_writing) + await asyncio.gather(*[drainer(stream) for _ in range(10)]) + self.assertEqual(drained, 10) + + self.loop.run_until_complete(main()) + + def test_drain_raises(self): + # See http://bugs.python.org/issue25441 + + # This test should not use asyncio for the mock server; the + # whole point of the test is to test for a bug in drain() + # where it never gives up the event loop but the socket is + # closed on the server side. + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + q = queue.Queue() + + def server(): + # Runs in a separate thread. + with socket.create_server(('localhost', 0)) as sock: + addr = sock.getsockname() + q.put(addr) + clt, _ = sock.accept() + clt.close() + + async def client(host, port): + reader, writer = await asyncio.open_connection(host, port) + + while True: + writer.write(b"foo\n") + await writer.drain() + + # Start the server thread and wait for it to be listening. + thread = threading.Thread(target=server) + thread.daemon = True + thread.start() + addr = q.get() + + # Should not be stuck in an infinite loop. + with self.assertRaises((ConnectionResetError, ConnectionAbortedError, + BrokenPipeError)): + self.loop.run_until_complete(client(*addr)) + + # Clean up the thread. (Only on success; on failure, it may + # be stuck in accept().) 
+ thread.join() + self.assertEqual([], messages) + + def test___repr__(self): + stream = asyncio.StreamReader(loop=self.loop) + self.assertEqual("", repr(stream)) + + def test___repr__nondefault_limit(self): + stream = asyncio.StreamReader(loop=self.loop, limit=123) + self.assertEqual("", repr(stream)) + + def test___repr__eof(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_eof() + self.assertEqual("", repr(stream)) + + def test___repr__data(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'data') + self.assertEqual("", repr(stream)) + + def test___repr__exception(self): + stream = asyncio.StreamReader(loop=self.loop) + exc = RuntimeError() + stream.set_exception(exc) + self.assertEqual("", + repr(stream)) + + def test___repr__waiter(self): + stream = asyncio.StreamReader(loop=self.loop) + stream._waiter = asyncio.Future(loop=self.loop) + self.assertRegex( + repr(stream), + r">") + stream._waiter.set_result(None) + self.loop.run_until_complete(stream._waiter) + stream._waiter = None + self.assertEqual("", repr(stream)) + + def test___repr__transport(self): + stream = asyncio.StreamReader(loop=self.loop) + stream._transport = mock.Mock() + stream._transport.__repr__ = mock.Mock() + stream._transport.__repr__.return_value = "" + self.assertEqual(">", repr(stream)) + + def test_IncompleteReadError_pickleable(self): + e = asyncio.IncompleteReadError(b'abc', 10) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_protocol=proto): + e2 = pickle.loads(pickle.dumps(e, protocol=proto)) + self.assertEqual(str(e), str(e2)) + self.assertEqual(e.partial, e2.partial) + self.assertEqual(e.expected, e2.expected) + + def test_LimitOverrunError_pickleable(self): + e = asyncio.LimitOverrunError('message', 10) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_protocol=proto): + e2 = pickle.loads(pickle.dumps(e, protocol=proto)) + self.assertEqual(str(e), str(e2)) + 
self.assertEqual(e.consumed, e2.consumed) + + def test_wait_closed_on_close(self): + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address)) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + f = rd.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = rd.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertFalse(wr.is_closing()) + wr.close() + self.assertTrue(wr.is_closing()) + self.loop.run_until_complete(wr.wait_closed()) + + def test_wait_closed_on_close_with_unread_data(self): + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address)) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + f = rd.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + wr.close() + self.loop.run_until_complete(wr.wait_closed()) + + def test_async_writer_api(self): + async def inner(httpd): + rd, wr = await asyncio.open_connection(*httpd.address) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + data = await rd.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await rd.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + wr.close() + await wr.wait_closed() + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(inner(httpd)) + + self.assertEqual(messages, []) + + def test_async_writer_api_exception_after_close(self): + async def inner(httpd): + rd, wr = await asyncio.open_connection(*httpd.address) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + data = await rd.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await rd.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + wr.close() + with self.assertRaises(ConnectionResetError): + 
wr.write(b'data') + await wr.drain() + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(inner(httpd)) + + self.assertEqual(messages, []) + + def test_eof_feed_when_closing_writer(self): + # See http://bugs.python.org/issue35065 + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address)) + + wr.close() + f = wr.wait_closed() + self.loop.run_until_complete(f) + self.assertTrue(rd.at_eof()) + f = rd.read() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'') + + self.assertEqual(messages, []) + + def test_unclosed_resource_warnings(self): + async def inner(httpd): + rd, wr = await asyncio.open_connection(*httpd.address) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + data = await rd.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await rd.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + with self.assertWarns(ResourceWarning) as cm: + del wr + gc.collect() + self.assertEqual(len(cm.warnings), 1) + self.assertTrue(str(cm.warnings[0].message).startswith("unclosed }] != [] + @unittest.expectedFailure + def test_unclosed_server_resource_warnings(self): + async def inner(rd, wr): + fut.set_result(True) + with self.assertWarns(ResourceWarning) as cm: + del wr + gc.collect() + self.assertEqual(len(cm.warnings), 1) + self.assertTrue(str(cm.warnings[0].message).startswith("unclosed " + ) + transport._returncode = None + self.assertEqual( + repr(transport), + "" + ) + transport._pid = None + transport._returncode = None + self.assertEqual( + repr(transport), + "" + ) + transport.close() + + +class SubprocessMixin: + + def test_stdin_stdout(self): + args = PROGRAM_CAT + + async def run(data): + proc = await asyncio.create_subprocess_exec( + *args, + 
stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + + # feed data + proc.stdin.write(data) + await proc.stdin.drain() + proc.stdin.close() + + # get output and exitcode + data = await proc.stdout.read() + exitcode = await proc.wait() + return (exitcode, data) + + task = run(b'some data') + task = asyncio.wait_for(task, 60.0) + exitcode, stdout = self.loop.run_until_complete(task) + self.assertEqual(exitcode, 0) + self.assertEqual(stdout, b'some data') + + def test_communicate(self): + args = PROGRAM_CAT + + async def run(data): + proc = await asyncio.create_subprocess_exec( + *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + stdout, stderr = await proc.communicate(data) + return proc.returncode, stdout + + task = run(b'some data') + task = asyncio.wait_for(task, support.LONG_TIMEOUT) + exitcode, stdout = self.loop.run_until_complete(task) + self.assertEqual(exitcode, 0) + self.assertEqual(stdout, b'some data') + + def test_communicate_none_input(self): + args = PROGRAM_CAT + + async def run(): + proc = await asyncio.create_subprocess_exec( + *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + return proc.returncode, stdout + + task = run() + task = asyncio.wait_for(task, support.LONG_TIMEOUT) + exitcode, stdout = self.loop.run_until_complete(task) + self.assertEqual(exitcode, 0) + self.assertEqual(stdout, b'') + + def test_shell(self): + proc = self.loop.run_until_complete( + asyncio.create_subprocess_shell('exit 7') + ) + exitcode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(exitcode, 7) + + def test_start_new_session(self): + # start the new process in a new session + proc = self.loop.run_until_complete( + asyncio.create_subprocess_shell( + 'exit 8', + start_new_session=True, + ) + ) + exitcode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(exitcode, 8) + + def test_kill(self): + args = PROGRAM_BLOCKED + proc = self.loop.run_until_complete( + 
asyncio.create_subprocess_exec(*args) + ) + proc.kill() + returncode = self.loop.run_until_complete(proc.wait()) + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGKILL, returncode) + + def test_kill_issue43884(self): + if sys.platform == 'win32': + blocking_shell_command = f'"{sys.executable}" -c "import time; time.sleep(2)"' + else: + blocking_shell_command = 'sleep 1; sleep 1' + creationflags = 0 + if sys.platform == 'win32': + from subprocess import CREATE_NEW_PROCESS_GROUP + # On windows create a new process group so that killing process + # kills the process and all its children. + creationflags = CREATE_NEW_PROCESS_GROUP + proc = self.loop.run_until_complete( + asyncio.create_subprocess_shell(blocking_shell_command, stdout=asyncio.subprocess.PIPE, + creationflags=creationflags) + ) + self.loop.run_until_complete(asyncio.sleep(1)) + if sys.platform == 'win32': + proc.send_signal(signal.CTRL_BREAK_EVENT) + # On windows it is an alias of terminate which sets the return code + proc.kill() + returncode = self.loop.run_until_complete(proc.wait()) + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGKILL, returncode) + + def test_terminate(self): + args = PROGRAM_BLOCKED + proc = self.loop.run_until_complete( + asyncio.create_subprocess_exec(*args) + ) + proc.terminate() + returncode = self.loop.run_until_complete(proc.wait()) + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGTERM, returncode) + + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + def test_send_signal(self): + # bpo-31034: Make sure that we get the default signal handler (killing + # the process). The parent process may have decided to ignore SIGHUP, + # and signal handlers are inherited. 
+ old_handler = signal.signal(signal.SIGHUP, signal.SIG_DFL) + try: + code = 'import time; print("sleeping", flush=True); time.sleep(3600)' + args = [sys.executable, '-c', code] + proc = self.loop.run_until_complete( + asyncio.create_subprocess_exec( + *args, + stdout=subprocess.PIPE, + ) + ) + + async def send_signal(proc): + # basic synchronization to wait until the program is sleeping + line = await proc.stdout.readline() + self.assertEqual(line, b'sleeping\n') + + proc.send_signal(signal.SIGHUP) + returncode = await proc.wait() + return returncode + + returncode = self.loop.run_until_complete(send_signal(proc)) + self.assertEqual(-signal.SIGHUP, returncode) + finally: + signal.signal(signal.SIGHUP, old_handler) + + def test_stdin_broken_pipe(self): + # buffer large enough to feed the whole pipe buffer + large_data = b'x' * support.PIPE_MAX_SIZE + + rfd, wfd = os.pipe() + self.addCleanup(os.close, rfd) + self.addCleanup(os.close, wfd) + if support.MS_WINDOWS: + handle = msvcrt.get_osfhandle(rfd) + os.set_handle_inheritable(handle, True) + code = textwrap.dedent(f''' + import os, msvcrt + handle = {handle} + fd = msvcrt.open_osfhandle(handle, os.O_RDONLY) + os.read(fd, 1) + ''') + from subprocess import STARTUPINFO + startupinfo = STARTUPINFO() + startupinfo.lpAttributeList = {"handle_list": [handle]} + kwargs = dict(startupinfo=startupinfo) + else: + code = f'import os; fd = {rfd}; os.read(fd, 1)' + kwargs = dict(pass_fds=(rfd,)) + + # the program ends before the stdin can be fed + proc = self.loop.run_until_complete( + asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=subprocess.PIPE, + **kwargs + ) + ) + + async def write_stdin(proc, data): + proc.stdin.write(data) + # Only exit the child process once the write buffer is filled + os.write(wfd, b'go') + await proc.stdin.drain() + + coro = write_stdin(proc, large_data) + # drain() must raise BrokenPipeError or ConnectionResetError + with test_utils.disable_logger(): + 
self.assertRaises((BrokenPipeError, ConnectionResetError), + self.loop.run_until_complete, coro) + self.loop.run_until_complete(proc.wait()) + + def test_communicate_ignore_broken_pipe(self): + # buffer large enough to feed the whole pipe buffer + large_data = b'x' * support.PIPE_MAX_SIZE + + # the program ends before the stdin can be fed + proc = self.loop.run_until_complete( + asyncio.create_subprocess_exec( + sys.executable, '-c', 'pass', + stdin=subprocess.PIPE, + ) + ) + + # communicate() must ignore BrokenPipeError when feeding stdin + self.loop.set_exception_handler(lambda loop, msg: None) + self.loop.run_until_complete(proc.communicate(large_data)) + self.loop.run_until_complete(proc.wait()) + + def test_pause_reading(self): + limit = 10 + size = (limit * 2 + 1) + + async def test_pause_reading(): + code = '\n'.join(( + 'import sys', + 'sys.stdout.write("x" * %s)' % size, + 'sys.stdout.flush()', + )) + + connect_read_pipe = self.loop.connect_read_pipe + + async def connect_read_pipe_mock(*args, **kw): + transport, protocol = await connect_read_pipe(*args, **kw) + transport.pause_reading = mock.Mock() + transport.resume_reading = mock.Mock() + return (transport, protocol) + + self.loop.connect_read_pipe = connect_read_pipe_mock + + proc = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + limit=limit, + ) + stdout_transport = proc._transport.get_pipe_transport(1) + + stdout, stderr = await proc.communicate() + + # The child process produced more than limit bytes of output, + # the stream reader transport should pause the protocol to not + # allocate too much memory. 
+ return (stdout, stdout_transport) + + # Issue #22685: Ensure that the stream reader pauses the protocol + # when the child process produces too much data + stdout, transport = self.loop.run_until_complete(test_pause_reading()) + + self.assertEqual(stdout, b'x' * size) + self.assertTrue(transport.pause_reading.called) + self.assertTrue(transport.resume_reading.called) + + def test_stdin_not_inheritable(self): + # asyncio issue #209: stdin must not be inheritable, otherwise + # the Process.communicate() hangs + async def len_message(message): + code = 'import sys; data = sys.stdin.read(); print(len(data))' + proc = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + close_fds=False, + ) + stdout, stderr = await proc.communicate(message) + exitcode = await proc.wait() + return (stdout, exitcode) + + output, exitcode = self.loop.run_until_complete(len_message(b'abc')) + self.assertEqual(output.rstrip(), b'3') + self.assertEqual(exitcode, 0) + + def test_empty_input(self): + + async def empty_input(): + code = 'import sys; data = sys.stdin.read(); print(len(data))' + proc = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + close_fds=False, + ) + stdout, stderr = await proc.communicate(b'') + exitcode = await proc.wait() + return (stdout, exitcode) + + output, exitcode = self.loop.run_until_complete(empty_input()) + self.assertEqual(output.rstrip(), b'0') + self.assertEqual(exitcode, 0) + + def test_devnull_input(self): + + async def empty_input(): + code = 'import sys; data = sys.stdin.read(); print(len(data))' + proc = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.DEVNULL, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + close_fds=False, + ) + stdout, stderr = 
await proc.communicate() + exitcode = await proc.wait() + return (stdout, exitcode) + + output, exitcode = self.loop.run_until_complete(empty_input()) + self.assertEqual(output.rstrip(), b'0') + self.assertEqual(exitcode, 0) + + def test_devnull_output(self): + + async def empty_output(): + code = 'import sys; data = sys.stdin.read(); print(len(data))' + proc = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + close_fds=False, + ) + stdout, stderr = await proc.communicate(b"abc") + exitcode = await proc.wait() + return (stdout, exitcode) + + output, exitcode = self.loop.run_until_complete(empty_output()) + self.assertEqual(output, None) + self.assertEqual(exitcode, 0) + + def test_devnull_error(self): + + async def empty_error(): + code = 'import sys; data = sys.stdin.read(); print(len(data))' + proc = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + close_fds=False, + ) + stdout, stderr = await proc.communicate(b"abc") + exitcode = await proc.wait() + return (stderr, exitcode) + + output, exitcode = self.loop.run_until_complete(empty_error()) + self.assertEqual(output, None) + self.assertEqual(exitcode, 0) + + @unittest.skipIf(sys.platform not in ('linux', 'android'), + "Don't have /dev/stdin") + def test_devstdin_input(self): + + async def devstdin_input(message): + code = 'file = open("/dev/stdin"); data = file.read(); print(len(data))' + proc = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + close_fds=False, + ) + stdout, stderr = await proc.communicate(message) + exitcode = await proc.wait() + return (stdout, exitcode) + + output, exitcode = self.loop.run_until_complete(devstdin_input(b'abc')) + 
self.assertEqual(output.rstrip(), b'3') + self.assertEqual(exitcode, 0) + + def test_cancel_process_wait(self): + # Issue #23140: cancel Process.wait() + + async def cancel_wait(): + proc = await asyncio.create_subprocess_exec(*PROGRAM_BLOCKED) + + # Create an internal future waiting on the process exit + task = self.loop.create_task(proc.wait()) + self.loop.call_soon(task.cancel) + try: + await task + except asyncio.CancelledError: + pass + + # Cancel the future + task.cancel() + + # Kill the process and wait until it is done + proc.kill() + await proc.wait() + + self.loop.run_until_complete(cancel_wait()) + + def test_cancel_make_subprocess_transport_exec(self): + + async def cancel_make_transport(): + coro = asyncio.create_subprocess_exec(*PROGRAM_BLOCKED) + task = self.loop.create_task(coro) + + self.loop.call_soon(task.cancel) + try: + await task + except asyncio.CancelledError: + pass + + # ignore the log: + # "Exception during subprocess creation, kill the subprocess" + with test_utils.disable_logger(): + self.loop.run_until_complete(cancel_make_transport()) + + def test_cancel_post_init(self): + + async def cancel_make_transport(): + coro = self.loop.subprocess_exec(asyncio.SubprocessProtocol, + *PROGRAM_BLOCKED) + task = self.loop.create_task(coro) + + self.loop.call_soon(task.cancel) + try: + await task + except asyncio.CancelledError: + pass + + # ignore the log: + # "Exception during subprocess creation, kill the subprocess" + with test_utils.disable_logger(): + self.loop.run_until_complete(cancel_make_transport()) + test_utils.run_briefly(self.loop) + + def test_close_kill_running(self): + + async def kill_running(): + create = self.loop.subprocess_exec(asyncio.SubprocessProtocol, + *PROGRAM_BLOCKED) + transport, protocol = await create + + kill_called = False + def kill(): + nonlocal kill_called + kill_called = True + orig_kill() + + proc = transport.get_extra_info('subprocess') + orig_kill = proc.kill + proc.kill = kill + returncode = 
transport.get_returncode() + transport.close() + await asyncio.wait_for(transport._wait(), 5) + return (returncode, kill_called) + + # Ignore "Close running child process: kill ..." log + with test_utils.disable_logger(): + try: + returncode, killed = self.loop.run_until_complete( + kill_running() + ) + except asyncio.TimeoutError: + self.skipTest( + "Timeout failure on waiting for subprocess stopping" + ) + self.assertIsNone(returncode) + + # transport.close() must kill the process if it is still running + self.assertTrue(killed) + test_utils.run_briefly(self.loop) + + def test_close_dont_kill_finished(self): + + async def kill_running(): + create = self.loop.subprocess_exec(asyncio.SubprocessProtocol, + *PROGRAM_BLOCKED) + transport, protocol = await create + proc = transport.get_extra_info('subprocess') + + # kill the process (but asyncio is not notified immediately) + proc.kill() + proc.wait() + + proc.kill = mock.Mock() + proc_returncode = proc.poll() + transport_returncode = transport.get_returncode() + transport.close() + return (proc_returncode, transport_returncode, proc.kill.called) + + # Ignore "Unknown child process pid ..." log of SafeChildWatcher, + # emitted because the test already consumes the exit status: + # proc.wait() + with test_utils.disable_logger(): + result = self.loop.run_until_complete(kill_running()) + test_utils.run_briefly(self.loop) + + proc_returncode, transport_return_code, killed = result + + self.assertIsNotNone(proc_returncode) + self.assertIsNone(transport_return_code) + + # transport.close() must not kill the process if it finished, even if + # the transport was not notified yet + self.assertFalse(killed) + + # Unlike SafeChildWatcher, FastChildWatcher does not pop the + # callbacks if waitpid() is called elsewhere. Let's clear them + # manually to avoid a warning when the watcher is detached. 
+ if (sys.platform != 'win32' and + isinstance(self, SubprocessFastWatcherTests)): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + asyncio.get_child_watcher()._callbacks.clear() + + async def _test_popen_error(self, stdin): + if sys.platform == 'win32': + target = 'asyncio.windows_utils.Popen' + else: + target = 'subprocess.Popen' + with mock.patch(target) as popen: + exc = ZeroDivisionError + popen.side_effect = exc + + with warnings.catch_warnings(record=True) as warns: + with self.assertRaises(exc): + await asyncio.create_subprocess_exec( + sys.executable, + '-c', + 'pass', + stdin=stdin + ) + self.assertEqual(warns, []) + + def test_popen_error(self): + # Issue #24763: check that the subprocess transport is closed + # when BaseSubprocessTransport fails + self.loop.run_until_complete(self._test_popen_error(stdin=None)) + + def test_popen_error_with_stdin_pipe(self): + # Issue #35721: check that newly created socket pair is closed when + # Popen fails + self.loop.run_until_complete( + self._test_popen_error(stdin=subprocess.PIPE)) + + def test_read_stdout_after_process_exit(self): + + async def execute(): + code = '\n'.join(['import sys', + 'for _ in range(64):', + ' sys.stdout.write("x" * 4096)', + 'sys.stdout.flush()', + 'sys.exit(1)']) + + process = await asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdout=asyncio.subprocess.PIPE, + ) + + while True: + data = await process.stdout.read(65536) + if data: + await asyncio.sleep(0.3) + else: + break + + self.loop.run_until_complete(execute()) + + def test_create_subprocess_exec_text_mode_fails(self): + async def execute(): + with self.assertRaises(ValueError): + await subprocess.create_subprocess_exec(sys.executable, + text=True) + + with self.assertRaises(ValueError): + await subprocess.create_subprocess_exec(sys.executable, + encoding="utf-8") + + with self.assertRaises(ValueError): + await subprocess.create_subprocess_exec(sys.executable, + 
errors="strict") + + self.loop.run_until_complete(execute()) + + def test_create_subprocess_shell_text_mode_fails(self): + + async def execute(): + with self.assertRaises(ValueError): + await subprocess.create_subprocess_shell(sys.executable, + text=True) + + with self.assertRaises(ValueError): + await subprocess.create_subprocess_shell(sys.executable, + encoding="utf-8") + + with self.assertRaises(ValueError): + await subprocess.create_subprocess_shell(sys.executable, + errors="strict") + + self.loop.run_until_complete(execute()) + + def test_create_subprocess_exec_with_path(self): + async def execute(): + p = await subprocess.create_subprocess_exec( + os_helper.FakePath(sys.executable), '-c', 'pass') + await p.wait() + p = await subprocess.create_subprocess_exec( + sys.executable, '-c', 'pass', os_helper.FakePath('.')) + await p.wait() + + self.assertIsNone(self.loop.run_until_complete(execute())) + + async def check_stdout_output(self, coro, output): + proc = await coro + stdout, _ = await proc.communicate() + self.assertEqual(stdout, output) + self.assertEqual(proc.returncode, 0) + task = asyncio.create_task(proc.wait()) + await asyncio.sleep(0) + self.assertEqual(task.result(), proc.returncode) + + def test_create_subprocess_env_shell(self) -> None: + async def main() -> None: + executable = sys.executable + if sys.platform == "win32": + executable = f'"{executable}"' + cmd = f'''{executable} -c "import os, sys; sys.stdout.write(os.getenv('FOO'))"''' + env = os.environ.copy() + env["FOO"] = "bar" + proc = await asyncio.create_subprocess_shell( + cmd, env=env, stdout=subprocess.PIPE + ) + return proc + + self.loop.run_until_complete(self.check_stdout_output(main(), b'bar')) + + def test_create_subprocess_env_exec(self) -> None: + async def main() -> None: + cmd = [sys.executable, "-c", + "import os, sys; sys.stdout.write(os.getenv('FOO'))"] + env = os.environ.copy() + env["FOO"] = "baz" + proc = await asyncio.create_subprocess_exec( + *cmd, env=env, 
stdout=subprocess.PIPE + ) + return proc + + self.loop.run_until_complete(self.check_stdout_output(main(), b'baz')) + + + def test_subprocess_concurrent_wait(self) -> None: + async def main() -> None: + proc = await asyncio.create_subprocess_exec( + *PROGRAM_CAT, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + stdout, _ = await proc.communicate(b'some data') + self.assertEqual(stdout, b"some data") + self.assertEqual(proc.returncode, 0) + self.assertEqual(await asyncio.gather(*[proc.wait() for _ in range(10)]), + [proc.returncode] * 10) + + self.loop.run_until_complete(main()) + + def test_subprocess_protocol_events(self): + # gh-108973: Test that all subprocess protocol methods are called. + # The protocol methods are not called in a determistic order. + # The order depends on the event loop and the operating system. + events = [] + fds = [1, 2] + expected = [ + ('pipe_data_received', 1, b'stdout'), + ('pipe_data_received', 2, b'stderr'), + ('pipe_connection_lost', 1), + ('pipe_connection_lost', 2), + 'process_exited', + ] + per_fd_expected = [ + 'pipe_data_received', + 'pipe_connection_lost', + ] + + class MyProtocol(asyncio.SubprocessProtocol): + def __init__(self, exit_future: asyncio.Future) -> None: + self.exit_future = exit_future + + def pipe_data_received(self, fd, data) -> None: + events.append(('pipe_data_received', fd, data)) + self.exit_maybe() + + def pipe_connection_lost(self, fd, exc) -> None: + events.append(('pipe_connection_lost', fd)) + self.exit_maybe() + + def process_exited(self) -> None: + events.append('process_exited') + self.exit_maybe() + + def exit_maybe(self): + # Only exit when we got all expected events + if len(events) >= len(expected): + self.exit_future.set_result(True) + + async def main() -> None: + loop = asyncio.get_running_loop() + exit_future = asyncio.Future() + code = 'import sys; sys.stdout.write("stdout"); sys.stderr.write("stderr")' + transport, _ = await loop.subprocess_exec(lambda: MyProtocol(exit_future), + 
sys.executable, '-c', code, stdin=None) + await exit_future + transport.close() + + return events + + events = self.loop.run_until_complete(main()) + + # First, make sure that we received all events + self.assertSetEqual(set(events), set(expected)) + + # Second, check order of pipe events per file descriptor + per_fd_events = {fd: [] for fd in fds} + for event in events: + if event == 'process_exited': + continue + name, fd = event[:2] + per_fd_events[fd].append(name) + + for fd in fds: + self.assertEqual(per_fd_events[fd], per_fd_expected, (fd, events)) + + def test_subprocess_communicate_stdout(self): + # See https://github.com/python/cpython/issues/100133 + async def get_command_stdout(cmd, *args): + proc = await asyncio.create_subprocess_exec( + cmd, *args, stdout=asyncio.subprocess.PIPE, + ) + stdout, _ = await proc.communicate() + return stdout.decode().strip() + + async def main(): + outputs = [f'foo{i}' for i in range(10)] + res = await asyncio.gather(*[get_command_stdout(sys.executable, '-c', + f'print({out!r})') for out in outputs]) + self.assertEqual(res, outputs) + + self.loop.run_until_complete(main()) + + @unittest.skipIf(sys.platform != 'linux', "Linux only") + @unittest.skip('TODO: RUSTPYTHON') + # Causing a hang + def test_subprocess_send_signal_race(self): + # See https://github.com/python/cpython/issues/87744 + async def main(): + for _ in range(10): + proc = await asyncio.create_subprocess_exec('sleep', '0.1') + await asyncio.sleep(0.1) + try: + proc.send_signal(signal.SIGUSR1) + except ProcessLookupError: + pass + self.assertNotEqual(await proc.wait(), 255) + + self.loop.run_until_complete(main()) + + +if sys.platform != 'win32': + # Unix + class SubprocessWatcherMixin(SubprocessMixin): + + Watcher = None + + def setUp(self): + super().setUp() + policy = asyncio.get_event_loop_policy() + self.loop = policy.new_event_loop() + self.set_event_loop(self.loop) + + watcher = self._get_watcher() + watcher.attach_loop(self.loop) + with 
warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + policy.set_child_watcher(watcher) + + def tearDown(self): + super().tearDown() + policy = asyncio.get_event_loop_policy() + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + watcher = policy.get_child_watcher() + policy.set_child_watcher(None) + watcher.attach_loop(None) + watcher.close() + + class SubprocessThreadedWatcherTests(SubprocessWatcherMixin, + test_utils.TestCase): + + def _get_watcher(self): + return unix_events.ThreadedChildWatcher() + + class SubprocessSafeWatcherTests(SubprocessWatcherMixin, + test_utils.TestCase): + + def _get_watcher(self): + with self.assertWarns(DeprecationWarning): + return unix_events.SafeChildWatcher() + + class MultiLoopChildWatcherTests(test_utils.TestCase): + + def test_warns(self): + with self.assertWarns(DeprecationWarning): + unix_events.MultiLoopChildWatcher() + + class SubprocessFastWatcherTests(SubprocessWatcherMixin, + test_utils.TestCase): + + def _get_watcher(self): + with self.assertWarns(DeprecationWarning): + return unix_events.FastChildWatcher() + + @unittest.skipUnless( + unix_events.can_use_pidfd(), + "operating system does not support pidfds", + ) + class SubprocessPidfdWatcherTests(SubprocessWatcherMixin, + test_utils.TestCase): + + def _get_watcher(self): + return unix_events.PidfdChildWatcher() + + + class GenericWatcherTests(test_utils.TestCase): + + def test_create_subprocess_fails_with_inactive_watcher(self): + watcher = mock.create_autospec(asyncio.AbstractChildWatcher) + watcher.is_active.return_value = False + + async def execute(): + asyncio.set_child_watcher(watcher) + + with self.assertRaises(RuntimeError): + await subprocess.create_subprocess_exec( + os_helper.FakePath(sys.executable), '-c', 'pass') + + watcher.add_child_handler.assert_not_called() + + with asyncio.Runner(loop_factory=asyncio.new_event_loop) as runner: + with warnings.catch_warnings(): + 
warnings.simplefilter('ignore', DeprecationWarning) + self.assertIsNone(runner.run(execute())) + self.assertListEqual(watcher.mock_calls, [ + mock.call.__enter__(), + mock.call.is_active(), + mock.call.__exit__(RuntimeError, mock.ANY, mock.ANY), + ], watcher.mock_calls) + + + @unittest.skipUnless( + unix_events.can_use_pidfd(), + "operating system does not support pidfds", + ) + def test_create_subprocess_with_pidfd(self): + async def in_thread(): + proc = await asyncio.create_subprocess_exec( + *PROGRAM_CAT, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + stdout, stderr = await proc.communicate(b"some data") + return proc.returncode, stdout + + async def main(): + # asyncio.Runner did not call asyncio.set_event_loop() + with warnings.catch_warnings(): + warnings.simplefilter('error', DeprecationWarning) + # get_event_loop() raises DeprecationWarning if + # set_event_loop() was never called and RuntimeError if + # it was called at least once. + with self.assertRaises((RuntimeError, DeprecationWarning)): + asyncio.get_event_loop_policy().get_event_loop() + return await asyncio.to_thread(asyncio.run, in_thread()) + with self.assertWarns(DeprecationWarning): + asyncio.set_child_watcher(asyncio.PidfdChildWatcher()) + try: + with asyncio.Runner(loop_factory=asyncio.new_event_loop) as runner: + returncode, stdout = runner.run(main()) + self.assertEqual(returncode, 0) + self.assertEqual(stdout, b'some data') + finally: + with self.assertWarns(DeprecationWarning): + asyncio.set_child_watcher(None) +else: + # Windows + class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.ProactorEventLoop() + self.set_event_loop(self.loop) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py new file mode 100644 index 00000000000..b3ec298f18c --- /dev/null +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -0,0 
+1,1123 @@ +# Adapted with permission from the EdgeDB project; +# license: PSFL. + +import weakref +import sys +import gc +import asyncio +import contextvars +import contextlib +from asyncio import taskgroups +import unittest +import warnings + +from test.test_asyncio.utils import await_without_task + +# To prevent a warning "test altered the execution environment" +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class MyExc(Exception): + pass + + +class MyBaseExc(BaseException): + pass + + +def get_error_types(eg): + return {type(exc) for exc in eg.exceptions} + + +def set_gc_state(enabled): + was_enabled = gc.isenabled() + if enabled: + gc.enable() + else: + gc.disable() + return was_enabled + + +@contextlib.contextmanager +def disable_gc(): + was_enabled = set_gc_state(enabled=False) + try: + yield + finally: + set_gc_state(enabled=was_enabled) + + +class BaseTestTaskGroup: + + async def test_taskgroup_01(self): + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_02(self): + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + await asyncio.sleep(0.15) + t2 = g.create_task(foo2()) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_03(self): + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + await asyncio.sleep(0.15) + # cancel t1 explicitly, i.e. everything should continue + # working as expected. 
+ t1.cancel() + + t2 = g.create_task(foo2()) + + self.assertTrue(t1.cancelled()) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_04(self): + + NUM = 0 + t2_cancel = False + t2 = None + + async def foo1(): + await asyncio.sleep(0.1) + 1 / 0 + + async def foo2(): + nonlocal NUM, t2_cancel + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + t2_cancel = True + raise + NUM += 1 + + async def runner(): + nonlocal NUM, t2 + + async with taskgroups.TaskGroup() as g: + g.create_task(foo1()) + t2 = g.create_task(foo2()) + + NUM += 10 + + with self.assertRaises(ExceptionGroup) as cm: + await asyncio.create_task(runner()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertEqual(NUM, 0) + self.assertTrue(t2_cancel) + self.assertTrue(t2.cancelled()) + + async def test_cancel_children_on_child_error(self): + # When a child task raises an error, the rest of the children + # are cancelled and the errors are gathered into an EG. + + NUM = 0 + t2_cancel = False + runner_cancel = False + + async def foo1(): + await asyncio.sleep(0.1) + 1 / 0 + + async def foo2(): + nonlocal NUM, t2_cancel + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + t2_cancel = True + raise + NUM += 1 + + async def runner(): + nonlocal NUM, runner_cancel + + async with taskgroups.TaskGroup() as g: + g.create_task(foo1()) + g.create_task(foo1()) + g.create_task(foo1()) + g.create_task(foo2()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + runner_cancel = True + raise + + NUM += 10 + + # The 3 foo1 sub tasks can be racy when the host is busy - if the + # cancellation happens in the middle, we'll see partial sub errors here + with self.assertRaises(ExceptionGroup) as cm: + await asyncio.create_task(runner()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + self.assertEqual(NUM, 0) + self.assertTrue(t2_cancel) + self.assertTrue(runner_cancel) + + async def test_cancellation(self): + + 
NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + async with taskgroups.TaskGroup() as g: + for _ in range(5): + g.create_task(foo()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError) as cm: + await r + + self.assertEqual(NUM, 5) + + async def test_taskgroup_07(self): + + NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup() as g: + for _ in range(5): + g.create_task(foo()) + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 15) + + async def test_taskgroup_08(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup() as g: + for _ in range(5): + g.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_09(self): + + t1 = t2 = None + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + nonlocal t1, t2 + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + await asyncio.sleep(0.1) + 1 / 0 + + try: + await runner() + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {ZeroDivisionError}) + else: + 
self.fail('ExceptionGroup was not raised') + + self.assertTrue(t1.cancelled()) + self.assertTrue(t2.cancelled()) + + async def test_taskgroup_10(self): + + t1 = t2 = None + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + nonlocal t1, t2 + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + 1 / 0 + + try: + await runner() + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {ZeroDivisionError}) + else: + self.fail('ExceptionGroup was not raised') + + self.assertTrue(t1.cancelled()) + self.assertTrue(t2.cancelled()) + + async def test_taskgroup_11(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(): + async with taskgroups.TaskGroup() as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_12(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup() as g1: + g1.create_task(asyncio.sleep(10)) + + async with taskgroups.TaskGroup() as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_13(self): + + async def 
crash_after(t): + await asyncio.sleep(t) + raise ValueError(t) + + async def runner(): + async with taskgroups.TaskGroup() as g1: + g1.create_task(crash_after(0.1)) + + async with taskgroups.TaskGroup() as g2: + g2.create_task(crash_after(10)) + + r = asyncio.create_task(runner()) + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ValueError}) + + async def test_taskgroup_14(self): + + async def crash_after(t): + await asyncio.sleep(t) + raise ValueError(t) + + async def runner(): + async with taskgroups.TaskGroup() as g1: + g1.create_task(crash_after(10)) + + async with taskgroups.TaskGroup() as g2: + g2.create_task(crash_after(0.1)) + + r = asyncio.create_task(runner()) + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError}) + + async def test_taskgroup_15(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup() as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_16(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def nested_runner(): + async with taskgroups.TaskGroup() as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + async def runner(): + t = asyncio.create_task(nested_runner()) + await t + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with 
self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_17(self): + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 10) + + async def test_taskgroup_18(self): + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + # This isn't a good idea, but we have to support + # this weird case. + raise MyExc + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t),{MyExc}) + else: + self.fail('ExceptionGroup was not raised') + + self.assertEqual(NUM, 10) + + async def test_taskgroup_19(self): + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise MyExc + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + r = asyncio.create_task(runner()) + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError}) + else: + self.fail('TasgGroupError was not raised') + + async def test_taskgroup_20(self): + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise KeyboardInterrupt + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(KeyboardInterrupt): + await runner() + + async def 
test_taskgroup_20a(self): + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise MyBaseExc + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(BaseExceptionGroup) as cm: + await runner() + + self.assertEqual( + get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError} + ) + + async def _test_taskgroup_21(self): + # This test doesn't work as asyncio, currently, doesn't + # correctly propagate KeyboardInterrupt (or SystemExit) -- + # those cause the event loop itself to crash. + # (Compare to the previous (passing) test -- that one raises + # a plain exception but raises KeyboardInterrupt in nested(); + # this test does it the other way around.) + + async def crash_soon(): + await asyncio.sleep(0.1) + raise KeyboardInterrupt + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise TypeError + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(KeyboardInterrupt): + await runner() + + async def test_taskgroup_21a(self): + + async def crash_soon(): + await asyncio.sleep(0.1) + raise MyBaseExc + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise TypeError + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(BaseExceptionGroup) as cm: + await runner() + + self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError}) + + async def test_taskgroup_22(self): + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(foo1()) + g.create_task(foo2()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.05) + r.cancel() + + with 
self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_23(self): + + async def do_job(delay): + await asyncio.sleep(delay) + + async with taskgroups.TaskGroup() as g: + for count in range(10): + await asyncio.sleep(0.1) + g.create_task(do_job(0.3)) + if count == 5: + self.assertLess(len(g._tasks), 5) + await asyncio.sleep(1.35) + self.assertEqual(len(g._tasks), 0) + + async def test_taskgroup_24(self): + + async def root(g): + await asyncio.sleep(0.1) + g.create_task(coro1(0.1)) + g.create_task(coro1(0.2)) + + async def coro1(delay): + await asyncio.sleep(delay) + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(root(g)) + + await runner() + + async def test_taskgroup_25(self): + nhydras = 0 + + async def hydra(g): + nonlocal nhydras + nhydras += 1 + await asyncio.sleep(0.01) + g.create_task(hydra(g)) + g.create_task(hydra(g)) + + async def hercules(): + while nhydras < 10: + await asyncio.sleep(0.015) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(hydra(g)) + g.create_task(hercules()) + + with self.assertRaises(ExceptionGroup) as cm: + await runner() + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + self.assertGreaterEqual(nhydras, 10) + + async def test_taskgroup_task_name(self): + async def coro(): + await asyncio.sleep(0) + async with taskgroups.TaskGroup() as g: + t = g.create_task(coro(), name="yolo") + self.assertEqual(t.get_name(), "yolo") + + async def test_taskgroup_task_context(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async with taskgroups.TaskGroup() as g: + ctx = contextvars.copy_context() + self.assertIsNone(ctx.get(cvar)) + t1 = g.create_task(coro(1), context=ctx) + await t1 + self.assertEqual(1, ctx.get(cvar)) + t2 = g.create_task(coro(2), context=ctx) + await t2 + self.assertEqual(2, ctx.get(cvar)) + + async def 
test_taskgroup_no_create_task_after_failure(self): + async def coro1(): + await asyncio.sleep(0.001) + 1 / 0 + async def coro2(g): + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + with self.assertRaises(RuntimeError): + g.create_task(coro1()) + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup() as g: + g.create_task(coro1()) + g.create_task(coro2(g)) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_context_manager_exit_raises(self): + # See https://github.com/python/cpython/issues/95289 + class CustomException(Exception): + pass + + async def raise_exc(): + raise CustomException + + @contextlib.asynccontextmanager + async def database(): + try: + yield + finally: + raise CustomException + + async def main(): + task = asyncio.current_task() + try: + async with taskgroups.TaskGroup() as tg: + async with database(): + tg.create_task(raise_exc()) + await asyncio.sleep(1) + except* CustomException as err: + self.assertEqual(task.cancelling(), 0) + self.assertEqual(len(err.exceptions), 2) + + else: + self.fail('CustomException not raised') + + await asyncio.create_task(main()) + + async def test_taskgroup_already_entered(self): + tg = taskgroups.TaskGroup() + async with tg: + with self.assertRaisesRegex(RuntimeError, "has already been entered"): + async with tg: + pass + + async def test_taskgroup_double_enter(self): + tg = taskgroups.TaskGroup() + async with tg: + pass + with self.assertRaisesRegex(RuntimeError, "has already been entered"): + async with tg: + pass + + async def test_taskgroup_finished(self): + async def create_task_after_tg_finish(): + tg = taskgroups.TaskGroup() + async with tg: + pass + coro = asyncio.sleep(0) + with self.assertRaisesRegex(RuntimeError, "is finished"): + tg.create_task(coro) + + # Make sure the coroutine was closed when submitted to the inactive tg + # (if not closed, a RuntimeWarning should have been raised) + with 
warnings.catch_warnings(record=True) as w: + await create_task_after_tg_finish() + self.assertEqual(len(w), 0) + + async def test_taskgroup_not_entered(self): + tg = taskgroups.TaskGroup() + coro = asyncio.sleep(0) + with self.assertRaisesRegex(RuntimeError, "has not been entered"): + tg.create_task(coro) + + async def test_taskgroup_without_parent_task(self): + tg = taskgroups.TaskGroup() + with self.assertRaisesRegex(RuntimeError, "parent task"): + await await_without_task(tg.__aenter__()) + coro = asyncio.sleep(0) + with self.assertRaisesRegex(RuntimeError, "has not been entered"): + tg.create_task(coro) + + async def test_coro_closed_when_tg_closed(self): + async def run_coro_after_tg_closes(): + async with taskgroups.TaskGroup() as tg: + pass + coro = asyncio.sleep(0) + with self.assertRaisesRegex(RuntimeError, "is finished"): + tg.create_task(coro) + + await run_coro_after_tg_closes() + + + # TODO: RUSTPYTHON + # AssertionError: 1 != 0 + @unittest.expectedFailure + async def test_cancelling_level_preserved(self): + async def raise_after(t, e): + await asyncio.sleep(t) + raise e() + + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(raise_after(0.0, RuntimeError)) + except* RuntimeError: + pass + self.assertEqual(asyncio.current_task().cancelling(), 0) + + # TODO: RUSTPYTHON + # AssertionError: 1 != 0 + @unittest.expectedFailure + async def test_nested_groups_both_cancelled(self): + async def raise_after(t, e): + await asyncio.sleep(t) + raise e() + + try: + async with asyncio.TaskGroup() as outer_tg: + try: + async with asyncio.TaskGroup() as inner_tg: + inner_tg.create_task(raise_after(0, RuntimeError)) + outer_tg.create_task(raise_after(0, ValueError)) + except* RuntimeError: + pass + else: + self.fail("RuntimeError not raised") + self.assertEqual(asyncio.current_task().cancelling(), 1) + except* ValueError: + pass + else: + self.fail("ValueError not raised") + self.assertEqual(asyncio.current_task().cancelling(), 0) + + + # TODO: RUSTPYTHON + 
# AssertionError: Sleep after group should have been cancelled + @unittest.expectedFailure + async def test_error_and_cancel(self): + event = asyncio.Event() + + async def raise_error(): + event.set() + await asyncio.sleep(0) + raise RuntimeError() + + async def inner(): + try: + async with taskgroups.TaskGroup() as tg: + tg.create_task(raise_error()) + await asyncio.sleep(1) + self.fail("Sleep in group should have been cancelled") + except* RuntimeError: + self.assertEqual(asyncio.current_task().cancelling(), 1) + self.assertEqual(asyncio.current_task().cancelling(), 1) + await asyncio.sleep(1) + self.fail("Sleep after group should have been cancelled") + + async def outer(): + t = asyncio.create_task(inner()) + await event.wait() + self.assertEqual(t.cancelling(), 0) + t.cancel() + self.assertEqual(t.cancelling(), 1) + with self.assertRaises(asyncio.CancelledError): + await t + self.assertTrue(t.cancelled()) + + await outer() + + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + async def test_exception_refcycles_direct(self): + """Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup""" + tg = asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + try: + async with tg: + raise _Done + except ExceptionGroup as e: + exc = e + + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + async def test_exception_refcycles_errors(self): + """Test that TaskGroup deletes self._errors, and __aexit__ args""" + tg = asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + try: + async with tg: + raise _Done + except* _Done as excs: + exc = excs.exceptions[0] + + self.assertIsInstance(exc, _Done) + self.assertListEqual(gc.get_referrers(exc), []) + + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + async def test_exception_refcycles_parent_task(self): + """Test that TaskGroup deletes self._parent_task""" + tg = 
asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + async def coro_fn(): + async with tg: + raise _Done + + try: + async with asyncio.TaskGroup() as tg2: + tg2.create_task(coro_fn()) + except* _Done as excs: + exc = excs.exceptions[0].exceptions[0] + + self.assertIsInstance(exc, _Done) + self.assertListEqual(gc.get_referrers(exc), []) + + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + async def test_exception_refcycles_parent_task_wr(self): + """Test that TaskGroup deletes self._parent_task and create_task() deletes task""" + tg = asyncio.TaskGroup() + exc = None + + class _Done(Exception): + pass + + async def coro_fn(): + async with tg: + raise _Done + + with disable_gc(): + try: + async with asyncio.TaskGroup() as tg2: + task_wr = weakref.ref(tg2.create_task(coro_fn())) + except* _Done as excs: + exc = excs.exceptions[0].exceptions[0] + + self.assertIsNone(task_wr()) + self.assertIsInstance(exc, _Done) + self.assertListEqual(gc.get_referrers(exc), []) + + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + async def test_exception_refcycles_propagate_cancellation_error(self): + """Test that TaskGroup deletes propagate_cancellation_error""" + tg = asyncio.TaskGroup() + exc = None + + try: + async with asyncio.timeout(-1): + async with tg: + await asyncio.sleep(0) + except TimeoutError as e: + exc = e.__cause__ + + self.assertIsInstance(exc, asyncio.CancelledError) + self.assertListEqual(gc.get_referrers(exc), []) + + @unittest.skip('TODO: RUSTPYTHON') + # NotImplementedError + async def test_exception_refcycles_base_error(self): + """Test that TaskGroup deletes self._base_error""" + class MyKeyboardInterrupt(KeyboardInterrupt): + pass + + tg = asyncio.TaskGroup() + exc = None + + try: + async with tg: + raise MyKeyboardInterrupt + except MyKeyboardInterrupt as e: + exc = e + + self.assertIsNotNone(exc) + self.assertListEqual(gc.get_referrers(exc), []) + + + async def 
test_cancels_task_if_created_during_creation(self): + # regression test for gh-128550 + ran = False + class MyError(Exception): + pass + + exc = None + try: + async with asyncio.TaskGroup() as tg: + async def third_task(): + raise MyError("third task failed") + + async def second_task(): + nonlocal ran + tg.create_task(third_task()) + with self.assertRaises(asyncio.CancelledError): + await asyncio.sleep(0) # eager tasks cancel here + await asyncio.sleep(0) # lazy tasks cancel here + ran = True + + tg.create_task(second_task()) + except* MyError as excs: + exc = excs.exceptions[0] + + self.assertTrue(ran) + self.assertIsInstance(exc, MyError) + + + async def test_cancellation_does_not_leak_out_of_tg(self): + class MyError(Exception): + pass + + async def throw_error(): + raise MyError + + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(throw_error()) + except* MyError: + pass + else: + self.fail("should have raised one MyError in group") + + # if this test fails this current task will be cancelled + # outside the task group and inside unittest internals + # we yield to the event loop with sleep(0) so that + # cancellation happens here and error is more understandable + await asyncio.sleep(0) + + +# TODO: RUSTPYTHON +# class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): +# loop_factory = asyncio.EventLoop + +class TestEagerTaskTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase): + @staticmethod + def loop_factory(): + loop = asyncio.EventLoop() + loop.set_task_factory(asyncio.eager_task_factory) + return loop + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py new file mode 100644 index 00000000000..325cd5833c3 --- /dev/null +++ b/Lib/test/test_asyncio/test_tasks.py @@ -0,0 +1,3603 @@ +"""Tests for tasks.py.""" + +import collections +import contextlib +import contextvars +import gc +import io +import random +import re +import sys +import 
traceback +import types +import unittest +from unittest import mock +from types import GenericAlias + +import asyncio +from asyncio import futures +from asyncio import tasks +from test.test_asyncio import utils as test_utils +from test import support +from test.support.script_helper import assert_python_ok + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +async def coroutine_function(): + pass + + +def format_coroutine(qualname, state, src, source_traceback, generator=False): + if generator: + state = '%s' % state + else: + state = '%s, defined' % state + if source_traceback is not None: + frame = source_traceback[-1] + return ('coro=<%s() %s at %s> created at %s:%s' + % (qualname, state, src, frame[0], frame[1])) + else: + return 'coro=<%s() %s at %s>' % (qualname, state, src) + + +def get_innermost_context(exc): + """ + Return information about the innermost exception context in the chain. + """ + depth = 0 + while True: + context = exc.__context__ + if context is None: + break + + exc = context + depth += 1 + + return (type(exc), exc.args, depth) + + +class Dummy: + + def __repr__(self): + return '' + + def __call__(self, *args): + pass + + +class CoroLikeObject: + def send(self, v): + raise StopIteration(42) + + def throw(self, *exc): + pass + + def close(self): + pass + + def __await__(self): + return self + + +class BaseTaskTests: + + Task = None + Future = None + + def new_task(self, loop, coro, name='TestTask', context=None): + return self.__class__.Task(coro, loop=loop, name=name, context=context) + + def new_future(self, loop): + return self.__class__.Future(loop=loop) + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.loop.set_task_factory(self.new_task) + self.loop.create_future = lambda: self.new_future(self.loop) + + def test_generic_alias(self): + task = self.__class__.Task[str] + self.assertEqual(task.__args__, (str,)) + self.assertIsInstance(task, GenericAlias) + + def 
test_task_cancel_message_getter(self): + async def coro(): + pass + t = self.new_task(self.loop, coro()) + self.assertTrue(hasattr(t, '_cancel_message')) + self.assertEqual(t._cancel_message, None) + + t.cancel('my message') + self.assertEqual(t._cancel_message, 'my message') + + with self.assertRaises(asyncio.CancelledError) as cm: + self.loop.run_until_complete(t) + + self.assertEqual('my message', cm.exception.args[0]) + + def test_task_cancel_message_setter(self): + async def coro(): + pass + t = self.new_task(self.loop, coro()) + t.cancel('my message') + t._cancel_message = 'my new message' + self.assertEqual(t._cancel_message, 'my new message') + + with self.assertRaises(asyncio.CancelledError) as cm: + self.loop.run_until_complete(t) + + self.assertEqual('my new message', cm.exception.args[0]) + + def test_task_del_collect(self): + class Evil: + def __del__(self): + gc.collect() + + async def run(): + return Evil() + + self.loop.run_until_complete( + asyncio.gather(*[ + self.new_task(self.loop, run()) for _ in range(100) + ])) + + def test_other_loop_future(self): + other_loop = asyncio.new_event_loop() + fut = self.new_future(other_loop) + + async def run(fut): + await fut + + try: + with self.assertRaisesRegex(RuntimeError, + r'Task .* got Future .* attached'): + self.loop.run_until_complete(run(fut)) + finally: + other_loop.close() + + def test_task_awaits_on_itself(self): + + async def test(): + await task + + task = asyncio.ensure_future(test(), loop=self.loop) + + with self.assertRaisesRegex(RuntimeError, + 'Task cannot await on itself'): + self.loop.run_until_complete(task) + + def test_task_class(self): + async def notmuch(): + return 'ok' + t = self.new_task(self.loop, notmuch()) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + self.assertIs(t.get_loop(), self.loop) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + t = self.new_task(loop, 
notmuch()) + self.assertIs(t._loop, loop) + loop.run_until_complete(t) + loop.close() + + def test_ensure_future_coroutine(self): + async def notmuch(): + return 'ok' + t = asyncio.ensure_future(notmuch(), loop=self.loop) + self.assertIs(t._loop, self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + + a = notmuch() + self.addCleanup(a.close) + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + asyncio.ensure_future(a) + + async def test(): + return asyncio.ensure_future(notmuch()) + t = self.loop.run_until_complete(test()) + self.assertIs(t._loop, self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + + # Deprecated in 3.10, undeprecated in 3.12 + asyncio.set_event_loop(self.loop) + self.addCleanup(asyncio.set_event_loop, None) + t = asyncio.ensure_future(notmuch()) + self.assertIs(t._loop, self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + + def test_ensure_future_future(self): + f_orig = self.new_future(self.loop) + f_orig.set_result('ko') + + f = asyncio.ensure_future(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + with self.assertRaises(ValueError): + f = asyncio.ensure_future(f_orig, loop=loop) + + loop.close() + + f = asyncio.ensure_future(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_ensure_future_task(self): + async def notmuch(): + return 'ok' + t_orig = self.new_task(self.loop, notmuch()) + t = asyncio.ensure_future(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + with self.assertRaises(ValueError): + t = asyncio.ensure_future(t_orig, loop=loop) 
+ + loop.close() + + t = asyncio.ensure_future(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_ensure_future_awaitable(self): + class Aw: + def __init__(self, coro): + self.coro = coro + def __await__(self): + return self.coro.__await__() + + async def coro(): + return 'ok' + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + fut = asyncio.ensure_future(Aw(coro()), loop=loop) + loop.run_until_complete(fut) + self.assertEqual(fut.result(), 'ok') + + def test_ensure_future_task_awaitable(self): + class Aw: + def __await__(self): + return asyncio.sleep(0, result='ok').__await__() + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + task = asyncio.ensure_future(Aw(), loop=loop) + loop.run_until_complete(task) + self.assertTrue(task.done()) + self.assertEqual(task.result(), 'ok') + self.assertIsInstance(task.get_coro(), types.CoroutineType) + loop.close() + + def test_ensure_future_neither(self): + with self.assertRaises(TypeError): + asyncio.ensure_future('ok') + + def test_ensure_future_error_msg(self): + loop = asyncio.new_event_loop() + f = self.new_future(self.loop) + with self.assertRaisesRegex(ValueError, 'The future belongs to a ' + 'different loop than the one specified as ' + 'the loop argument'): + asyncio.ensure_future(f, loop=loop) + loop.close() + + def test_get_stack(self): + T = None + + async def foo(): + await bar() + + async def bar(): + # test get_stack() + f = T.get_stack(limit=1) + try: + self.assertEqual(f[0].f_code.co_name, 'foo') + finally: + f = None + + # test print_stack() + file = io.StringIO() + T.print_stack(limit=1, file=file) + file.seek(0) + tb = file.read() + self.assertRegex(tb, r'foo\(\) running') + + async def runner(): + nonlocal T + T = asyncio.ensure_future(foo(), loop=self.loop) + await T + + self.loop.run_until_complete(runner()) + + def test_task_repr(self): + self.loop.set_debug(False) + + async def notmuch(): + return 'abc' + + # test coroutine function + 
self.assertEqual(notmuch.__name__, 'notmuch') + self.assertRegex(notmuch.__qualname__, + r'\w+.test_task_repr..notmuch') + self.assertEqual(notmuch.__module__, __name__) + + filename, lineno = test_utils.get_function_source(notmuch) + src = "%s:%s" % (filename, lineno) + + # test coroutine object + gen = notmuch() + coro_qualname = 'BaseTaskTests.test_task_repr..notmuch' + self.assertEqual(gen.__name__, 'notmuch') + self.assertEqual(gen.__qualname__, coro_qualname) + + # test pending Task + t = self.new_task(self.loop, gen) + t.add_done_callback(Dummy()) + + coro = format_coroutine(coro_qualname, 'running', src, + t._source_traceback, generator=True) + self.assertEqual(repr(t), + "()]>" % coro) + + # test cancelling Task + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), + "()]>" % coro) + + # test cancelled Task + self.assertRaises(asyncio.CancelledError, + self.loop.run_until_complete, t) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) + self.assertEqual(repr(t), + "" % coro) + + # test finished Task + t = self.new_task(self.loop, notmuch()) + self.loop.run_until_complete(t) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) + self.assertEqual(repr(t), + "" % coro) + + def test_task_repr_autogenerated(self): + async def notmuch(): + return 123 + + t1 = self.new_task(self.loop, notmuch(), None) + t2 = self.new_task(self.loop, notmuch(), None) + self.assertNotEqual(repr(t1), repr(t2)) + + match1 = re.match(r"^' % re.escape(repr(fut))) + + fut.set_result(None) + self.loop.run_until_complete(task) + + def test_task_basics(self): + + async def outer(): + a = await inner1() + b = await inner2() + return a+b + + async def inner1(): + return 42 + + async def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_exception_chaining_after_await(self): + # Test that when awaiting on a task when an exception is already + # active, if the task 
raises an exception it will be chained + # with the original. + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def raise_error(): + raise ValueError + + async def run(): + try: + raise KeyError(3) + except Exception as exc: + task = self.new_task(loop, raise_error()) + try: + await task + except Exception as exc: + self.assertEqual(type(exc), ValueError) + chained = exc.__context__ + self.assertEqual((type(chained), chained.args), + (KeyError, (3,))) + + try: + task = self.new_task(loop, run()) + loop.run_until_complete(task) + finally: + loop.close() + + def test_exception_chaining_after_await_with_context_cycle(self): + # Check trying to create an exception context cycle: + # https://bugs.python.org/issue40696 + has_cycle = None + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def process_exc(exc): + raise exc + + async def run(): + nonlocal has_cycle + try: + raise KeyError('a') + except Exception as exc: + task = self.new_task(loop, process_exc(exc)) + try: + await task + except BaseException as exc: + has_cycle = (exc is exc.__context__) + # Prevent a hang if has_cycle is True. + exc.__context__ = None + + try: + task = self.new_task(loop, run()) + loop.run_until_complete(task) + finally: + loop.close() + # This also distinguishes from the initial has_cycle=None. + self.assertEqual(has_cycle, False) + + + def test_cancelling(self): + loop = asyncio.new_event_loop() + + async def task(): + await asyncio.sleep(10) + + try: + t = self.new_task(loop, task()) + self.assertFalse(t.cancelling()) + self.assertNotIn(" cancelling ", repr(t)) + self.assertTrue(t.cancel()) + self.assertTrue(t.cancelling()) + self.assertIn(" cancelling ", repr(t)) + + # Since we commented out two lines from Task.cancel(), + # this t.cancel() call now returns True. 
+ # self.assertFalse(t.cancel()) + self.assertTrue(t.cancel()) + + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + finally: + loop.close() + + def test_uncancel_basic(self): + loop = asyncio.new_event_loop() + + async def task(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + asyncio.current_task().uncancel() + await asyncio.sleep(10) + + try: + t = self.new_task(loop, task()) + loop.run_until_complete(asyncio.sleep(0.01)) + + # Cancel first sleep + self.assertTrue(t.cancel()) + self.assertIn(" cancelling ", repr(t)) + self.assertEqual(t.cancelling(), 1) + self.assertFalse(t.cancelled()) # Task is still not complete + loop.run_until_complete(asyncio.sleep(0.01)) + + # after .uncancel() + self.assertNotIn(" cancelling ", repr(t)) + self.assertEqual(t.cancelling(), 0) + self.assertFalse(t.cancelled()) # Task is still not complete + + # Cancel second sleep + self.assertTrue(t.cancel()) + self.assertEqual(t.cancelling(), 1) + self.assertFalse(t.cancelled()) # Task is still not complete + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.cancelled()) # Finally, task complete + self.assertTrue(t.done()) + + # uncancel is no longer effective after the task is complete + t.uncancel() + self.assertTrue(t.cancelled()) + self.assertTrue(t.done()) + finally: + loop.close() + + def test_uncancel_structured_blocks(self): + # This test recreates the following high-level structure using uncancel():: + # + # async def make_request_with_timeout(): + # try: + # async with asyncio.timeout(1): + # # Structured block affected by the timeout: + # await make_request() + # await make_another_request() + # except TimeoutError: + # pass # There was a timeout + # # Outer code not affected by the timeout: + # await unrelated_code() + + loop = asyncio.new_event_loop() + + async def make_request_with_timeout(*, sleep: float, timeout: float): + task = asyncio.current_task() + loop = task.get_loop() 
+ + timed_out = False + structured_block_finished = False + outer_code_reached = False + + def on_timeout(): + nonlocal timed_out + timed_out = True + task.cancel() + + timeout_handle = loop.call_later(timeout, on_timeout) + try: + try: + # Structured block affected by the timeout + await asyncio.sleep(sleep) + structured_block_finished = True + finally: + timeout_handle.cancel() + if ( + timed_out + and task.uncancel() == 0 + and type(sys.exception()) is asyncio.CancelledError + ): + # Note the five rules that are needed here to satisfy proper + # uncancellation: + # + # 1. handle uncancellation in a `finally:` block to allow for + # plain returns; + # 2. our `timed_out` flag is set, meaning that it was our event + # that triggered the need to uncancel the task, regardless of + # what exception is raised; + # 3. we can call `uncancel()` because *we* called `cancel()` + # before; + # 4. we call `uncancel()` but we only continue converting the + # CancelledError to TimeoutError if `uncancel()` caused the + # cancellation request count go down to 0. We need to look + # at the counter vs having a simple boolean flag because our + # code might have been nested (think multiple timeouts). See + # commit 7fce1063b6e5a366f8504e039a8ccdd6944625cd for + # details. + # 5. we only convert CancelledError to TimeoutError; for other + # exceptions raised due to the cancellation (like + # a ConnectionLostError from a database client), simply + # propagate them. + # + # Those checks need to take place in this exact order to make + # sure the `cancelling()` counter always stays in sync. + # + # Additionally, the original stimulus to `cancel()` the task + # needs to be unscheduled to avoid re-cancelling the task later. + # Here we do it by cancelling `timeout_handle` in the `finally:` + # block. 
+ raise TimeoutError + except TimeoutError: + self.assertTrue(timed_out) + + # Outer code not affected by the timeout: + outer_code_reached = True + await asyncio.sleep(0) + return timed_out, structured_block_finished, outer_code_reached + + try: + # Test which timed out. + t1 = self.new_task(loop, make_request_with_timeout(sleep=10.0, timeout=0.1)) + timed_out, structured_block_finished, outer_code_reached = ( + loop.run_until_complete(t1) + ) + self.assertTrue(timed_out) + self.assertFalse(structured_block_finished) # it was cancelled + self.assertTrue(outer_code_reached) # task got uncancelled after leaving + # the structured block and continued until + # completion + self.assertEqual(t1.cancelling(), 0) # no pending cancellation of the outer task + + # Test which did not time out. + t2 = self.new_task(loop, make_request_with_timeout(sleep=0, timeout=10.0)) + timed_out, structured_block_finished, outer_code_reached = ( + loop.run_until_complete(t2) + ) + self.assertFalse(timed_out) + self.assertTrue(structured_block_finished) + self.assertTrue(outer_code_reached) + self.assertEqual(t2.cancelling(), 0) + finally: + loop.close() + + def test_uncancel_resets_must_cancel(self): + + async def coro(): + await fut + return 42 + + loop = asyncio.new_event_loop() + fut = asyncio.Future(loop=loop) + task = self.new_task(loop, coro()) + loop.run_until_complete(asyncio.sleep(0)) # Get task waiting for fut + fut.set_result(None) # Make task runnable + try: + task.cancel() # Enter cancelled state + self.assertEqual(task.cancelling(), 1) + self.assertTrue(task._must_cancel) + + task.uncancel() # Undo cancellation + self.assertEqual(task.cancelling(), 0) + self.assertFalse(task._must_cancel) + finally: + res = loop.run_until_complete(task) + self.assertEqual(res, 42) + loop.close() + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + async def task(): + await asyncio.sleep(10.0) + return 
12 + + t = self.new_task(loop, task()) + loop.call_soon(t.cancel) + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_with_message_then_future_result(self): + # Test Future.result() after calling cancel() with a message. + cases = [ + ((), ()), + ((None,), ()), + (('my message',), ('my message',)), + # Non-string values should roundtrip. + ((5,), (5,)), + ] + for cancel_args, expected_args in cases: + with self.subTest(cancel_args=cancel_args): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def sleep(): + await asyncio.sleep(10) + + async def coro(): + task = self.new_task(loop, sleep()) + await asyncio.sleep(0) + task.cancel(*cancel_args) + done, pending = await asyncio.wait([task]) + task.result() + + task = self.new_task(loop, coro()) + with self.assertRaises(asyncio.CancelledError) as cm: + loop.run_until_complete(task) + exc = cm.exception + self.assertEqual(exc.args, expected_args) + + actual = get_innermost_context(exc) + self.assertEqual(actual, + (asyncio.CancelledError, expected_args, 0)) + + def test_cancel_with_message_then_future_exception(self): + # Test Future.exception() after calling cancel() with a message. + cases = [ + ((), ()), + ((None,), ()), + (('my message',), ('my message',)), + # Non-string values should roundtrip. 
+ ((5,), (5,)), + ] + for cancel_args, expected_args in cases: + with self.subTest(cancel_args=cancel_args): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def sleep(): + await asyncio.sleep(10) + + async def coro(): + task = self.new_task(loop, sleep()) + await asyncio.sleep(0) + task.cancel(*cancel_args) + done, pending = await asyncio.wait([task]) + task.exception() + + task = self.new_task(loop, coro()) + with self.assertRaises(asyncio.CancelledError) as cm: + loop.run_until_complete(task) + exc = cm.exception + self.assertEqual(exc.args, expected_args) + + actual = get_innermost_context(exc) + self.assertEqual(actual, + (asyncio.CancelledError, expected_args, 0)) + + def test_cancellation_exception_context(self): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + fut = loop.create_future() + + async def sleep(): + fut.set_result(None) + await asyncio.sleep(10) + + async def coro(): + inner_task = self.new_task(loop, sleep()) + await fut + loop.call_soon(inner_task.cancel, 'msg') + try: + await inner_task + except asyncio.CancelledError as ex: + raise ValueError("cancelled") from ex + + task = self.new_task(loop, coro()) + with self.assertRaises(ValueError) as cm: + loop.run_until_complete(task) + exc = cm.exception + self.assertEqual(exc.args, ('cancelled',)) + + actual = get_innermost_context(exc) + self.assertEqual(actual, + (asyncio.CancelledError, ('msg',), 1)) + + def test_cancel_with_message_before_starting_task(self): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def sleep(): + await asyncio.sleep(10) + + async def coro(): + task = self.new_task(loop, sleep()) + # We deliberately leave out the sleep here. 
+ task.cancel('my message') + done, pending = await asyncio.wait([task]) + task.exception() + + task = self.new_task(loop, coro()) + with self.assertRaises(asyncio.CancelledError) as cm: + loop.run_until_complete(task) + exc = cm.exception + self.assertEqual(exc.args, ('my message',)) + + actual = get_innermost_context(exc) + self.assertEqual(actual, + (asyncio.CancelledError, ('my message',), 0)) + + def test_cancel_yield(self): + async def task(): + await asyncio.sleep(0) + await asyncio.sleep(0) + return 12 + + t = self.new_task(self.loop, task()) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = self.new_future(self.loop) + + async def task(): + await f + return 12 + + t = self.new_task(self.loop, task()) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = self.new_future(self.loop) + + async def task(): + await f + return 12 + + t = self.new_task(self.loop, task()) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = self.new_future(self.loop) + fut2 = self.new_future(self.loop) + + async def task(): + await fut1 + try: + await fut2 + except asyncio.CancelledError: + return 42 + + t = self.new_task(self.loop, task()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. 
+ fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = self.new_future(self.loop) + fut2 = self.new_future(self.loop) + fut3 = self.new_future(self.loop) + + async def task(): + await fut1 + try: + await fut2 + except asyncio.CancelledError: + pass + res = await fut3 + return res + + t = self.new_task(self.loop, task()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + await asyncio.sleep(100) + return 12 + + t = self.new_task(loop, task()) + self.assertFalse(t.cancelled()) + self.assertRaises( + asyncio.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_cancel_at_end(self): + """coroutine end right after task is cancelled""" + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. 
+ return 12 + + t = self.new_task(loop, task()) + self.assertFalse(t.cancelled()) + self.assertRaises( + asyncio.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_cancel_awaited_task(self): + # This tests for a relatively rare condition when + # a task cancellation is requested for a task which is not + # currently blocked, such as a task cancelling itself. + # In this situation we must ensure that whatever next future + # or task the cancelled task blocks on is cancelled correctly + # as well. See also bpo-34872. + loop = asyncio.new_event_loop() + self.addCleanup(lambda: loop.close()) + + task = nested_task = None + fut = self.new_future(loop) + + async def nested(): + await fut + + async def coro(): + nonlocal nested_task + # Create a sub-task and wait for it to run. + nested_task = self.new_task(loop, nested()) + await asyncio.sleep(0) + + # Request the current task to be cancelled. + task.cancel() + # Block on the nested task, which should be immediately + # cancelled. + await nested_task + + task = self.new_task(loop, coro()) + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(task) + + self.assertTrue(task.cancelled()) + self.assertTrue(nested_task.cancelled()) + self.assertTrue(fut.cancelled()) + + def assert_text_contains(self, text, substr): + if substr not in text: + raise RuntimeError(f'text {substr!r} not found in:\n>>>{text}<<<') + + def test_cancel_traceback_for_future_result(self): + # When calling Future.result() on a cancelled task, check that the + # line of code that was interrupted is included in the traceback. + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def nested(): + # This will get cancelled immediately. 
+ await asyncio.sleep(10) + + async def coro(): + task = self.new_task(loop, nested()) + await asyncio.sleep(0) + task.cancel() + await task # search target + + task = self.new_task(loop, coro()) + try: + loop.run_until_complete(task) + except asyncio.CancelledError: + tb = traceback.format_exc() + self.assert_text_contains(tb, "await asyncio.sleep(10)") + # The intermediate await should also be included. + self.assert_text_contains(tb, "await task # search target") + else: + self.fail('CancelledError did not occur') + + def test_cancel_traceback_for_future_exception(self): + # When calling Future.exception() on a cancelled task, check that the + # line of code that was interrupted is included in the traceback. + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def nested(): + # This will get cancelled immediately. + await asyncio.sleep(10) + + async def coro(): + task = self.new_task(loop, nested()) + await asyncio.sleep(0) + task.cancel() + done, pending = await asyncio.wait([task]) + task.exception() # search target + + task = self.new_task(loop, coro()) + try: + loop.run_until_complete(task) + except asyncio.CancelledError: + tb = traceback.format_exc() + self.assert_text_contains(tb, "await asyncio.sleep(10)") + # The intermediate await should also be included. 
+ self.assert_text_contains(tb, + "task.exception() # search target") + else: + self.fail('CancelledError did not occur') + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + x = 0 + + async def task(): + nonlocal x + while x < 10: + await asyncio.sleep(0.1) + x += 1 + if x == 2: + loop.stop() + + t = self.new_task(loop, task()) + with self.assertRaises(RuntimeError) as cm: + loop.run_until_complete(t) + self.assertEqual(str(cm.exception), + 'Event loop stopped before Future completed.') + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + t.cancel() + self.assertRaises(asyncio.CancelledError, loop.run_until_complete, t) + + def test_log_traceback(self): + async def coro(): + pass + + task = self.new_task(self.loop, coro()) + with self.assertRaisesRegex(ValueError, 'can only be set to False'): + task._log_traceback = True + self.loop.run_until_complete(task) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = self.new_test_loop(gen) + + a = self.new_task(loop, asyncio.sleep(0.1)) + b = self.new_task(loop, asyncio.sleep(0.15)) + + async def foo(): + done, pending = await asyncio.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(self.new_task(loop, foo())) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. 
+ res = loop.run_until_complete(self.new_task(loop, foo())) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_duplicate_coroutines(self): + + async def coro(s): + return s + c = self.loop.create_task(coro('test')) + task = self.new_task( + self.loop, + asyncio.wait([c, c, self.loop.create_task(coro('spam'))])) + + done, pending = self.loop.run_until_complete(task) + + self.assertFalse(pending) + self.assertEqual(set(f.result() for f in done), {'test', 'spam'}) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + asyncio.wait(set())) + + # -1 is an invalid return_when value + sleep_coro = asyncio.sleep(10.0) + wait_coro = asyncio.wait([sleep_coro], return_when=-1) + self.assertRaises(ValueError, + self.loop.run_until_complete, wait_coro) + + sleep_coro.close() + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + a = self.new_task(loop, asyncio.sleep(10.0)) + b = self.new_task(loop, asyncio.sleep(0.1)) + task = self.new_task( + loop, + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED)) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b])) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + async def coro1(): + await asyncio.sleep(0) + + async def coro2(): + await asyncio.sleep(0) + await asyncio.sleep(0) + + a = self.new_task(self.loop, coro1()) + b = self.new_task(self.loop, coro2()) + task = self.new_task( + 
self.loop, + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED)) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + # first_exception, task already has exception + a = self.new_task(loop, asyncio.sleep(10.0)) + + async def exc(): + raise ZeroDivisionError('err') + + b = self.new_task(loop, exc()) + task = self.new_task( + loop, + asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION)) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b])) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = self.new_test_loop(gen) + + # first_exception, exception during waiting + a = self.new_task(loop, asyncio.sleep(10.0)) + + async def exc(): + await asyncio.sleep(0.01) + raise ZeroDivisionError('err') + + b = self.new_task(loop, exc()) + task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b])) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = self.new_test_loop(gen) + + a = self.new_task(loop, asyncio.sleep(0.1)) + + 
async def sleeper(): + await asyncio.sleep(0.15) + raise ZeroDivisionError('really') + + b = self.new_task(loop, sleeper()) + + async def foo(): + done, pending = await asyncio.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(self.new_task(loop, foo())) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(self.new_task(loop, foo())) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = self.new_test_loop(gen) + + a = self.new_task(loop, asyncio.sleep(0.1)) + b = self.new_task(loop, asyncio.sleep(0.15)) + + async def foo(): + done, pending = await asyncio.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(self.new_task(loop, foo())) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b])) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + a = self.new_task(loop, asyncio.sleep(0.1)) + b = self.new_task(loop, asyncio.sleep(0.15)) + + done, pending = loop.run_until_complete( + asyncio.wait([b, a], timeout=0.1)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b])) + + def test_wait_with_iterator_of_tasks(self): + + def 
gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = self.new_test_loop(gen) + + a = self.new_task(loop, asyncio.sleep(0.1)) + b = self.new_task(loop, asyncio.sleep(0.15)) + + async def foo(): + done, pending = await asyncio.wait(iter([b, a])) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(self.new_task(loop, foo())) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + + def test_wait_generator(self): + async def func(a): + return a + + loop = self.new_test_loop() + + async def main(): + tasks = (self.new_task(loop, func(i)) for i in range(10)) + done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + self.assertEqual(len(done), 10) + self.assertEqual(len(pending), 0) + + loop.run_until_complete(main()) + + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + async def sleeper(dt, x): + nonlocal time_shifted + await asyncio.sleep(dt) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + async def try_iterator(awaitables): + values = [] + for f in asyncio.as_completed(awaitables): + values.append(await f) + return values + + async def try_async_iterator(awaitables): + values = [] + async for f in asyncio.as_completed(awaitables): + values.append(await f) + return values + + for foo in try_iterator, try_async_iterator: + with self.subTest(method=foo.__name__): + loop = self.new_test_loop(gen) + # disable "slow callback" warning + loop.slow_callback_duration = 1.0 + + completed = set() + time_shifted = False + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + res = loop.run_until_complete(self.new_task(loop, foo([b, c, a]))) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + 
self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + def test_as_completed_same_tasks_in_as_out(self): + # Ensures that asynchronously iterating as_completed's iterator + # yields awaitables are the same awaitables that were passed in when + # those awaitables are futures. + async def try_async_iterator(awaitables): + awaitables_out = set() + async for out_aw in asyncio.as_completed(awaitables): + awaitables_out.add(out_aw) + return awaitables_out + + async def coro(i): + return i + + with contextlib.closing(asyncio.new_event_loop()) as loop: + # Coroutines shouldn't be yielded back as finished coroutines + # can't be re-used. + awaitables_in = frozenset( + (coro(0), coro(1), coro(2), coro(3)) + ) + awaitables_out = loop.run_until_complete( + try_async_iterator(awaitables_in) + ) + if awaitables_in - awaitables_out != awaitables_in: + raise self.failureException('Got original coroutines ' + 'out of as_completed iterator.') + + # Tasks should be yielded back. + coro_obj_a = coro('a') + task_b = loop.create_task(coro('b')) + coro_obj_c = coro('c') + task_d = loop.create_task(coro('d')) + awaitables_in = frozenset( + (coro_obj_a, task_b, coro_obj_c, task_d) + ) + awaitables_out = loop.run_until_complete( + try_async_iterator(awaitables_in) + ) + if awaitables_in & awaitables_out != {task_b, task_d}: + raise self.failureException('Only tasks should be yielded ' + 'from as_completed iterator ' + 'as-is.') + + def test_as_completed_with_timeout(self): + + def gen(): + yield + yield 0 + yield 0 + yield 0.1 + + async def try_iterator(): + values = [] + for f in asyncio.as_completed([a, b], timeout=0.12): + if values: + loop.advance_time(0.02) + try: + v = await f + values.append((1, v)) + except asyncio.TimeoutError as exc: + values.append((2, exc)) + return values + + async def try_async_iterator(): + values = [] + try: + async for f in asyncio.as_completed([a, b], timeout=0.12): + v = await f + values.append((1, v)) + loop.advance_time(0.02) + except 
asyncio.TimeoutError as exc: + values.append((2, exc)) + return values + + for foo in try_iterator, try_async_iterator: + with self.subTest(method=foo.__name__): + loop = self.new_test_loop(gen) + a = loop.create_task(asyncio.sleep(0.1, 'a')) + b = loop.create_task(asyncio.sleep(0.15, 'b')) + + res = loop.run_until_complete(self.new_task(loop, foo())) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertIsInstance(res[1][1], asyncio.TimeoutError) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b])) + + def test_as_completed_with_unused_timeout(self): + + def gen(): + yield + yield 0 + yield 0.01 + + async def try_iterator(): + for f in asyncio.as_completed([a], timeout=1): + v = await f + self.assertEqual(v, 'a') + + async def try_async_iterator(): + async for f in asyncio.as_completed([a], timeout=1): + v = await f + self.assertEqual(v, 'a') + + for foo in try_iterator, try_async_iterator: + with self.subTest(method=foo.__name__): + a = asyncio.sleep(0.01, 'a') + loop = self.new_test_loop(gen) + loop.run_until_complete(self.new_task(loop, foo())) + loop.close() + + def test_as_completed_resume_iterator(self): + # Test that as_completed returns an iterator that can be resumed + # the next time iteration is performed (i.e. 
if __iter__ is called + # again) + async def try_iterator(awaitables): + iterations = 0 + iterator = asyncio.as_completed(awaitables) + collected = [] + for f in iterator: + collected.append(await f) + iterations += 1 + if iterations == 2: + break + self.assertEqual(len(collected), 2) + + # Resume same iterator: + for f in iterator: + collected.append(await f) + return collected + + async def try_async_iterator(awaitables): + iterations = 0 + iterator = asyncio.as_completed(awaitables) + collected = [] + async for f in iterator: + collected.append(await f) + iterations += 1 + if iterations == 2: + break + self.assertEqual(len(collected), 2) + + # Resume same iterator: + async for f in iterator: + collected.append(await f) + return collected + + async def coro(i): + return i + + with contextlib.closing(asyncio.new_event_loop()) as loop: + for foo in try_iterator, try_async_iterator: + with self.subTest(method=foo.__name__): + results = loop.run_until_complete( + foo((coro(0), coro(1), coro(2), coro(3))) + ) + self.assertCountEqual(results, (0, 1, 2, 3)) + + def test_as_completed_reverse_wait(self): + # Tests the plain iterator style of as_completed iteration to + # ensure that the first future awaited resolves to the first + # completed awaitable from the set we passed in, even if it wasn't + # the first future generated by as_completed. 
+ def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.05, 'a') + b = asyncio.sleep(0.10, 'b') + fs = {a, b} + + async def test(): + futs = list(asyncio.as_completed(fs)) + self.assertEqual(len(futs), 2) + + x = await futs[1] + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = await futs[0] + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + loop.run_until_complete(test()) + + def test_as_completed_concurrent(self): + # Ensure that more than one future or coroutine yielded from + # as_completed can be awaited concurrently. + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + async def try_iterator(fs): + return list(asyncio.as_completed(fs)) + + async def try_async_iterator(fs): + return [f async for f in asyncio.as_completed(fs)] + + for runner in try_iterator, try_async_iterator: + with self.subTest(method=runner.__name__): + a = asyncio.sleep(0.05, 'a') + b = asyncio.sleep(0.05, 'b') + fs = {a, b} + + async def test(): + futs = await runner(fs) + self.assertEqual(len(futs), 2) + done, pending = await asyncio.wait( + [asyncio.ensure_future(fut) for fut in futs] + ) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + loop = self.new_test_loop(gen) + loop.run_until_complete(test()) + + def test_as_completed_duplicate_coroutines(self): + + async def coro(s): + return s + + async def try_iterator(): + result = [] + c = coro('ham') + for f in asyncio.as_completed([c, c, coro('spam')]): + result.append(await f) + return result + + async def try_async_iterator(): + result = [] + c = coro('ham') + async for f in asyncio.as_completed([c, c, coro('spam')]): + result.append(await f) + return result + + for runner in try_iterator, try_async_iterator: + with self.subTest(method=runner.__name__): + fut = self.new_task(self.loop, runner()) + 
self.loop.run_until_complete(fut) + result = fut.result() + self.assertEqual(set(result), {'ham', 'spam'}) + self.assertEqual(len(result), 2) + + def test_as_completed_coroutine_without_loop(self): + async def coro(): + return 42 + + a = coro() + self.addCleanup(a.close) + + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + futs = asyncio.as_completed([a]) + list(futs) + + def test_as_completed_coroutine_use_running_loop(self): + loop = self.new_test_loop() + + async def coro(): + return 42 + + async def test(): + futs = list(asyncio.as_completed([coro()])) + self.assertEqual(len(futs), 1) + self.assertEqual(await futs[0], 42) + + loop.run_until_complete(test()) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = self.new_test_loop(gen) + + async def sleeper(dt, arg): + await asyncio.sleep(dt/2) + res = await asyncio.sleep(dt/2, arg) + return res + + t = self.new_task(loop, sleeper(0.1, 'yeah')) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_when_delay_is_nan(self): + + def gen(): + yield + + loop = self.new_test_loop(gen) + + async def sleeper(): + await asyncio.sleep(float("nan")) + + t = self.new_task(loop, sleeper()) + + with self.assertRaises(ValueError): + loop.run_until_complete(t) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + t = self.new_task(loop, asyncio.sleep(10.0, 'yeah')) + + handle = None + orig_call_later = loop.call_later + + def call_later(delay, callback, *args): + nonlocal handle + handle = orig_call_later(delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + 
self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + async def sleep(dt): + await asyncio.sleep(dt) + + async def doit(): + sleeper = self.new_task(loop, sleep(5000)) + loop.call_later(0.1, sleeper.cancel) + try: + await sleeper + except asyncio.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = self.new_future(self.loop) + + async def coro(): + await fut + + task = self.new_task(self.loop, coro()) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_task_set_methods(self): + async def notmuch(): + return 'ko' + + gen = notmuch() + task = self.new_task(self.loop, gen) + + with self.assertRaisesRegex(RuntimeError, 'not support set_result'): + task.set_result('ok') + + with self.assertRaisesRegex(RuntimeError, 'not support set_exception'): + task.set_exception(ValueError()) + + self.assertEqual( + self.loop.run_until_complete(task), + 'ko') + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. 
+ + class Fut(asyncio.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, *args, **kwargs): + self.cb_added = True + super().add_done_callback(*args, **kwargs) + + fut = Fut(loop=self.loop) + result = None + + async def wait_for_future(): + nonlocal result + result = await fut + + t = self.new_task(self.loop, wait_for_future()) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + async def sleeper(): + await asyncio.sleep(10) + + base_exc = SystemExit() + + async def notmutch(): + try: + await sleeper() + except asyncio.CancelledError: + raise base_exc + + task = self.new_task(loop, notmutch()) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(SystemExit, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(asyncio.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(asyncio.iscoroutinefunction(fn1)) + + async def fn2(): + pass + self.assertTrue(asyncio.iscoroutinefunction(fn2)) + + self.assertFalse(asyncio.iscoroutinefunction(mock.Mock())) + self.assertTrue(asyncio.iscoroutinefunction(mock.AsyncMock())) + + def test_coroutine_non_gen_function(self): + async def func(): + return 'test' + + self.assertTrue(asyncio.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(asyncio.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def 
test_coroutine_non_gen_function_return_future(self): + fut = self.new_future(self.loop) + + async def func(): + return fut + + async def coro(): + fut.set_result('test') + + t1 = self.new_task(self.loop, func()) + t2 = self.new_task(self.loop, coro()) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, fut) + self.assertIsNone(t2.result()) + + def test_current_task(self): + self.assertIsNone(asyncio.current_task(loop=self.loop)) + + async def coro(loop): + self.assertIs(asyncio.current_task(), task) + + self.assertIs(asyncio.current_task(None), task) + self.assertIs(asyncio.current_task(), task) + + task = self.new_task(self.loop, coro(self.loop)) + self.loop.run_until_complete(task) + self.assertIsNone(asyncio.current_task(loop=self.loop)) + + def test_current_task_with_interleaving_tasks(self): + self.assertIsNone(asyncio.current_task(loop=self.loop)) + + fut1 = self.new_future(self.loop) + fut2 = self.new_future(self.loop) + + async def coro1(loop): + self.assertTrue(asyncio.current_task() is task1) + await fut1 + self.assertTrue(asyncio.current_task() is task1) + fut2.set_result(True) + + async def coro2(loop): + self.assertTrue(asyncio.current_task() is task2) + fut1.set_result(True) + await fut2 + self.assertTrue(asyncio.current_task() is task2) + + task1 = self.new_task(self.loop, coro1(self.loop)) + task2 = self.new_task(self.loop, coro2(self.loop)) + + self.loop.run_until_complete(asyncio.wait((task1, task2))) + self.assertIsNone(asyncio.current_task(loop=self.loop)) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. 
+ proof = 0 + waiter = self.new_future(self.loop) + + async def inner(): + nonlocal proof + try: + await waiter + except asyncio.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + async def outer(): + nonlocal proof + try: + await inner() + except asyncio.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = asyncio.ensure_future(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = self.new_future(self.loop) + + async def inner(): + nonlocal proof + await waiter + proof += 1 + + async def outer(): + nonlocal proof + with self.assertWarns(DeprecationWarning): + d, p = await asyncio.wait([asyncio.create_task(inner())]) + proof += 100 + + f = asyncio.ensure_future(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = self.new_future(self.loop) + outer = asyncio.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = self.new_future(self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel_inner(self): + inner = self.new_future(self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def 
test_shield_cancel_outer(self): + inner = self.new_future(self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + self.assertEqual(0, 0 if outer._callbacks is None else len(outer._callbacks)) + + def test_shield_shortcut(self): + fut = self.new_future(self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(asyncio.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = self.new_future(self.loop) + + async def inner(): + nonlocal proof + await waiter + proof += 1 + + async def outer(): + nonlocal proof + await asyncio.shield(inner()) + proof += 100 + + f = asyncio.ensure_future(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = self.new_future(self.loop) + child2 = self.new_future(self.loop) + parent = asyncio.gather(child1, child2) + outer = asyncio.shield(parent) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = self.new_future(self.loop) + child2 = self.new_future(self.loop) + inner1 = asyncio.shield(child1) + inner2 = asyncio.shield(child2) + parent = asyncio.gather(inner1, inner2) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. 
+ test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), asyncio.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + def test_shield_coroutine_without_loop(self): + async def coro(): + return 42 + + inner = coro() + self.addCleanup(inner.close) + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + asyncio.shield(inner) + + def test_shield_coroutine_use_running_loop(self): + async def coro(): + return 42 + + async def test(): + return asyncio.shield(coro()) + outer = self.loop.run_until_complete(test()) + self.assertEqual(outer._loop, self.loop) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_coroutine_use_global_loop(self): + # Deprecated in 3.10, undeprecated in 3.12 + async def coro(): + return 42 + + asyncio.set_event_loop(self.loop) + self.addCleanup(asyncio.set_event_loop, None) + outer = asyncio.shield(coro()) + self.assertEqual(outer._loop, self.loop) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_as_completed_invalid_args(self): + # as_completed() expects a list of futures, not a future instance + # TypeError should be raised either on iterator construction or first + # iteration + + # Plain iterator + fut = self.new_future(self.loop) + with self.assertRaises(TypeError): + iterator = asyncio.as_completed(fut) + next(iterator) + coro = coroutine_function() + with self.assertRaises(TypeError): + iterator = asyncio.as_completed(coro) + next(iterator) + coro.close() + + # Async iterator + async def try_async_iterator(aw): + async for f in asyncio.as_completed(aw): + break + + fut = self.new_future(self.loop) + with self.assertRaises(TypeError): + self.loop.run_until_complete(try_async_iterator(fut)) + coro = coroutine_function() + with self.assertRaises(TypeError): + 
self.loop.run_until_complete(try_async_iterator(coro)) + coro.close() + + def test_wait_invalid_args(self): + fut = self.new_future(self.loop) + + # wait() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(fut)) + coro = coroutine_function() + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(coro)) + coro.close() + + # wait() expects at least a future + self.assertRaises(ValueError, self.loop.run_until_complete, + asyncio.wait([])) + + def test_log_destroyed_pending_task(self): + Task = self.__class__.Task + + async def kill_me(loop): + future = self.new_future(loop) + await future + # at this point, the only reference to kill_me() task is + # the Task._wakeup() method in future._callbacks + raise Exception("code never reached") + + mock_handler = mock.Mock() + self.loop.set_debug(True) + self.loop.set_exception_handler(mock_handler) + + # schedule the task + coro = kill_me(self.loop) + task = asyncio.ensure_future(coro, loop=self.loop) + + self.assertEqual(asyncio.all_tasks(loop=self.loop), {task}) + + asyncio.set_event_loop(None) + + # execute the task so it waits for future + self.loop._run_once() + self.assertEqual(len(self.loop._ready), 0) + + coro = None + source_traceback = task._source_traceback + task = None + + # no more reference to kill_me() task: the task is destroyed by the GC + support.gc_collect() + + self.assertEqual(asyncio.all_tasks(loop=self.loop), set()) + + mock_handler.assert_called_with(self.loop, { + 'message': 'Task was destroyed but it is pending!', + 'task': mock.ANY, + 'source_traceback': source_traceback, + }) + mock_handler.reset_mock() + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_not_called_after_cancel(self, m_log): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + async def coro(): + raise TypeError + + async def runner(): + task = self.new_task(loop, coro()) + await asyncio.sleep(0.05) + 
task.cancel() + task = None + + loop.run_until_complete(runner()) + self.assertFalse(m_log.error.called) + + def test_task_source_traceback(self): + self.loop.set_debug(True) + + task = self.new_task(self.loop, coroutine_function()) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(task._source_traceback, list) + self.assertEqual(task._source_traceback[-2][:3], + (__file__, + lineno, + 'test_task_source_traceback')) + self.loop.run_until_complete(task) + + def test_cancel_gather_1(self): + """Ensure that a gathering future refuses to be cancelled once all + children are done""" + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + fut = self.new_future(loop) + async def create(): + # The indirection fut->child_coro is needed since otherwise the + # gathering task is done at the same time as the child future + async def child_coro(): + return await fut + gather_future = asyncio.gather(child_coro()) + return asyncio.ensure_future(gather_future) + gather_task = loop.run_until_complete(create()) + + cancel_result = None + def cancelling_callback(_): + nonlocal cancel_result + cancel_result = gather_task.cancel() + fut.add_done_callback(cancelling_callback) + + fut.set_result(42) # calls the cancelling_callback after fut is done() + + # At this point the task should complete. + loop.run_until_complete(gather_task) + + # Python issue #26923: asyncio.gather drops cancellation + self.assertEqual(cancel_result, False) + self.assertFalse(gather_task.cancelled()) + self.assertEqual(gather_task.result(), [42]) + + def test_cancel_gather_2(self): + cases = [ + ((), ()), + ((None,), ()), + (('my message',), ('my message',)), + # Non-string values should roundtrip. 
+ ((5,), (5,)), + ] + for cancel_args, expected_args in cases: + with self.subTest(cancel_args=cancel_args): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + async def test(): + time = 0 + while True: + time += 0.05 + await asyncio.gather(asyncio.sleep(0.05), + return_exceptions=True) + if time > 1: + return + + async def main(): + qwe = self.new_task(loop, test()) + await asyncio.sleep(0.2) + qwe.cancel(*cancel_args) + await qwe + + try: + loop.run_until_complete(main()) + except asyncio.CancelledError as exc: + self.assertEqual(exc.args, expected_args) + actual = get_innermost_context(exc) + self.assertEqual( + actual, + (asyncio.CancelledError, expected_args, 0), + ) + else: + self.fail( + 'gather() does not propagate CancelledError ' + 'raised by inner task to the gather() caller.' + ) + + def test_exception_traceback(self): + # See http://bugs.python.org/issue28843 + + async def foo(): + 1 / 0 + + async def main(): + task = self.new_task(self.loop, foo()) + await asyncio.sleep(0) # skip one loop iteration + self.assertIsNotNone(task.exception().__traceback__) + + self.loop.run_until_complete(main()) + + @mock.patch('asyncio.base_events.logger') + def test_error_in_call_soon(self, m_log): + def call_soon(callback, *args, **kwargs): + raise ValueError + self.loop.call_soon = call_soon + + async def coro(): + pass + + self.assertFalse(m_log.error.called) + + with self.assertRaises(ValueError): + gen = coro() + try: + self.new_task(self.loop, gen) + finally: + gen.close() + gc.collect() # For PyPy or other GCs. 
+ + self.assertTrue(m_log.error.called) + message = m_log.error.call_args[0][0] + self.assertIn('Task was destroyed but it is pending', message) + + self.assertEqual(asyncio.all_tasks(self.loop), set()) + + def test_create_task_with_noncoroutine(self): + with self.assertRaisesRegex(TypeError, + "a coroutine was expected, got 123"): + self.new_task(self.loop, 123) + + # test it for the second time to ensure that caching + # in asyncio.iscoroutine() doesn't break things. + with self.assertRaisesRegex(TypeError, + "a coroutine was expected, got 123"): + self.new_task(self.loop, 123) + + def test_create_task_with_async_function(self): + + async def coro(): + pass + + task = self.new_task(self.loop, coro()) + self.assertIsInstance(task, self.Task) + self.loop.run_until_complete(task) + + # test it for the second time to ensure that caching + # in asyncio.iscoroutine() doesn't break things. + task = self.new_task(self.loop, coro()) + self.assertIsInstance(task, self.Task) + self.loop.run_until_complete(task) + + def test_create_task_with_asynclike_function(self): + task = self.new_task(self.loop, CoroLikeObject()) + self.assertIsInstance(task, self.Task) + self.assertEqual(self.loop.run_until_complete(task), 42) + + # test it for the second time to ensure that caching + # in asyncio.iscoroutine() doesn't break things. 
+ task = self.new_task(self.loop, CoroLikeObject()) + self.assertIsInstance(task, self.Task) + self.assertEqual(self.loop.run_until_complete(task), 42) + + def test_bare_create_task(self): + + async def inner(): + return 1 + + async def coro(): + task = asyncio.create_task(inner()) + self.assertIsInstance(task, self.Task) + ret = await task + self.assertEqual(1, ret) + + self.loop.run_until_complete(coro()) + + def test_bare_create_named_task(self): + + async def coro_noop(): + pass + + async def coro(): + task = asyncio.create_task(coro_noop(), name='No-op') + self.assertEqual(task.get_name(), 'No-op') + await task + + self.loop.run_until_complete(coro()) + + def test_context_1(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + async def sub(): + await asyncio.sleep(0.01) + self.assertEqual(cvar.get(), 'nope') + cvar.set('something else') + + async def main(): + self.assertEqual(cvar.get(), 'nope') + subtask = self.new_task(loop, sub()) + cvar.set('yes') + self.assertEqual(cvar.get(), 'yes') + await subtask + self.assertEqual(cvar.get(), 'yes') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + loop.run_until_complete(task) + finally: + loop.close() + + def test_context_2(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + async def main(): + def fut_on_done(fut): + # This change must not pollute the context + # of the "main()" task. 
+ cvar.set('something else') + + self.assertEqual(cvar.get(), 'nope') + + for j in range(2): + fut = self.new_future(loop) + fut.add_done_callback(fut_on_done) + cvar.set(f'yes{j}') + loop.call_soon(fut.set_result, None) + await fut + self.assertEqual(cvar.get(), f'yes{j}') + + for i in range(3): + # Test that task passed its context to add_done_callback: + cvar.set(f'yes{i}-{j}') + await asyncio.sleep(0.001) + self.assertEqual(cvar.get(), f'yes{i}-{j}') + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual(cvar.get(), 'nope') + + def test_context_3(self): + # Run 100 Tasks in parallel, each modifying cvar. + + cvar = contextvars.ContextVar('cvar', default=-1) + + async def sub(num): + for i in range(10): + cvar.set(num + i) + await asyncio.sleep(random.uniform(0.001, 0.05)) + self.assertEqual(cvar.get(), num + i) + + async def main(): + tasks = [] + for i in range(100): + task = loop.create_task(sub(random.randint(0, 10))) + tasks.append(task) + + await asyncio.gather(*tasks) + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(main()) + finally: + loop.close() + + self.assertEqual(cvar.get(), -1) + + def test_context_4(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = self.new_task(loop, coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = self.new_task(loop, coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_5(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): 
+ ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = asyncio.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = asyncio.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_6(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = loop.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = loop.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = loop.create_task(main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_get_coro(self): + loop = asyncio.new_event_loop() + coro = coroutine_function() + try: + task = self.new_task(loop, coro) + loop.run_until_complete(task) + self.assertIs(task.get_coro(), coro) + finally: + loop.close() + + def test_get_context(self): + loop = asyncio.new_event_loop() + coro = coroutine_function() + context = contextvars.copy_context() + try: + task = self.new_task(loop, coro, context=context) + loop.run_until_complete(task) + self.assertIs(task.get_context(), context) + finally: + loop.close() + + def test_proper_refcounts(self): + # see: https://github.com/python/cpython/issues/126083 + class Break: + def __str__(self): + raise RuntimeError("break") + + obj = object() + initial_refcount = sys.getrefcount(obj) + + coro = coroutine_function() + with contextlib.closing(asyncio.EventLoop()) as loop: + task = asyncio.Task.__new__(asyncio.Task) + + for _ in range(5): + with self.assertRaisesRegex(RuntimeError, 
'break'): + task.__init__(coro, loop=loop, context=obj, name=Break()) + + coro.close() + del task + + self.assertEqual(sys.getrefcount(obj), initial_refcount) + + +def add_subclass_tests(cls): + BaseTask = cls.Task + BaseFuture = cls.Future + + if BaseTask is None or BaseFuture is None: + return cls + + class CommonFuture: + def __init__(self, *args, **kwargs): + self.calls = collections.defaultdict(lambda: 0) + super().__init__(*args, **kwargs) + + def add_done_callback(self, *args, **kwargs): + self.calls['add_done_callback'] += 1 + return super().add_done_callback(*args, **kwargs) + + class Task(CommonFuture, BaseTask): + pass + + class Future(CommonFuture, BaseFuture): + pass + + def test_subclasses_ctask_cfuture(self): + fut = self.Future(loop=self.loop) + + async def func(): + self.loop.call_soon(lambda: fut.set_result('spam')) + return await fut + + task = self.Task(func(), loop=self.loop) + + result = self.loop.run_until_complete(task) + + self.assertEqual(result, 'spam') + + self.assertEqual( + dict(task.calls), + {'add_done_callback': 1}) + + self.assertEqual( + dict(fut.calls), + {'add_done_callback': 1}) + + # Add patched Task & Future back to the test case + cls.Task = Task + cls.Future = Future + + # Add an extra unit-test + cls.test_subclasses_ctask_cfuture = test_subclasses_ctask_cfuture + + # Disable the "test_task_source_traceback" test + # (the test is hardcoded for a particular call stack, which + # is slightly different for Task subclasses) + cls.test_task_source_traceback = None + + return cls + + +class SetMethodsTest: + + def test_set_result_causes_invalid_state(self): + Future = type(self).Future + self.loop.call_exception_handler = exc_handler = mock.Mock() + + async def foo(): + await asyncio.sleep(0.1) + return 10 + + coro = foo() + task = self.new_task(self.loop, coro) + Future.set_result(task, 'spam') + + self.assertEqual( + self.loop.run_until_complete(task), + 'spam') + + exc_handler.assert_called_once() + exc = 
exc_handler.call_args[0][0]['exception'] + with self.assertRaisesRegex(asyncio.InvalidStateError, + r'step\(\): already done'): + raise exc + + coro.close() + + def test_set_exception_causes_invalid_state(self): + class MyExc(Exception): + pass + + Future = type(self).Future + self.loop.call_exception_handler = exc_handler = mock.Mock() + + async def foo(): + await asyncio.sleep(0.1) + return 10 + + coro = foo() + task = self.new_task(self.loop, coro) + Future.set_exception(task, MyExc()) + + with self.assertRaises(MyExc): + self.loop.run_until_complete(task) + + exc_handler.assert_called_once() + exc = exc_handler.call_args[0][0]['exception'] + with self.assertRaisesRegex(asyncio.InvalidStateError, + r'step\(\): already done'): + raise exc + + coro.close() + + +@unittest.skipUnless(hasattr(futures, '_CFuture') and + hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class CTask_CFuture_Tests(BaseTaskTests, SetMethodsTest, + test_utils.TestCase): + + Task = getattr(tasks, '_CTask', None) + Future = getattr(futures, '_CFuture', None) + + @support.refcount_test + def test_refleaks_in_task___init__(self): + gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount') + async def coro(): + pass + task = self.new_task(self.loop, coro()) + self.loop.run_until_complete(task) + refs_before = gettotalrefcount() + for i in range(100): + task.__init__(coro(), loop=self.loop) + self.loop.run_until_complete(task) + self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10) + + def test_del__log_destroy_pending_segfault(self): + async def coro(): + pass + task = self.new_task(self.loop, coro()) + self.loop.run_until_complete(task) + with self.assertRaises(AttributeError): + del task._log_destroy_pending + + +@unittest.skipUnless(hasattr(futures, '_CFuture') and + hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +@add_subclass_tests +class CTask_CFuture_SubclassTests(BaseTaskTests, test_utils.TestCase): + + Task = getattr(tasks, '_CTask', 
None) + Future = getattr(futures, '_CFuture', None) + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +@add_subclass_tests +class CTaskSubclass_PyFuture_Tests(BaseTaskTests, test_utils.TestCase): + + Task = getattr(tasks, '_CTask', None) + Future = futures._PyFuture + + +@unittest.skipUnless(hasattr(futures, '_CFuture'), + 'requires the C _asyncio module') +@add_subclass_tests +class PyTask_CFutureSubclass_Tests(BaseTaskTests, test_utils.TestCase): + + Future = getattr(futures, '_CFuture', None) + Task = tasks._PyTask + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class CTask_PyFuture_Tests(BaseTaskTests, test_utils.TestCase): + + Task = getattr(tasks, '_CTask', None) + Future = futures._PyFuture + + +@unittest.skipUnless(hasattr(futures, '_CFuture'), + 'requires the C _asyncio module') +class PyTask_CFuture_Tests(BaseTaskTests, test_utils.TestCase): + + Task = tasks._PyTask + Future = getattr(futures, '_CFuture', None) + + +# TODO: RUSTPYTHON +# class PyTask_PyFuture_Tests(BaseTaskTests, SetMethodsTest, +# test_utils.TestCase): + +# Task = tasks._PyTask +# Future = futures._PyFuture + + +# TODO: RUSTPYTHON +# @add_subclass_tests +# class PyTask_PyFuture_SubclassTests(BaseTaskTests, test_utils.TestCase): +# Task = tasks._PyTask +# Future = futures._PyFuture + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class CTask_Future_Tests(test_utils.TestCase): + + def test_foobar(self): + class Fut(asyncio.Future): + @property + def get_loop(self): + raise AttributeError + + async def coro(): + await fut + return 'spam' + + self.loop = asyncio.new_event_loop() + try: + fut = Fut(loop=self.loop) + self.loop.call_later(0.1, fut.set_result, 1) + task = self.loop.create_task(coro()) + res = self.loop.run_until_complete(task) + finally: + self.loop.close() + + self.assertEqual(res, 'spam') + + +class BaseTaskIntrospectionTests: + _register_task = None + 
_unregister_task = None + _enter_task = None + _leave_task = None + + def test__register_task_1(self): + class TaskLike: + @property + def _loop(self): + return loop + + def done(self): + return False + + task = TaskLike() + loop = mock.Mock() + + self.assertEqual(asyncio.all_tasks(loop), set()) + self._register_task(task) + self.assertEqual(asyncio.all_tasks(loop), {task}) + self._unregister_task(task) + + def test__register_task_2(self): + class TaskLike: + def get_loop(self): + return loop + + def done(self): + return False + + task = TaskLike() + loop = mock.Mock() + + self.assertEqual(asyncio.all_tasks(loop), set()) + self._register_task(task) + self.assertEqual(asyncio.all_tasks(loop), {task}) + self._unregister_task(task) + + def test__register_task_3(self): + class TaskLike: + def get_loop(self): + return loop + + def done(self): + return True + + task = TaskLike() + loop = mock.Mock() + + self.assertEqual(asyncio.all_tasks(loop), set()) + self._register_task(task) + self.assertEqual(asyncio.all_tasks(loop), set()) + self._unregister_task(task) + + def test__enter_task(self): + task = mock.Mock() + loop = mock.Mock() + self.assertIsNone(asyncio.current_task(loop)) + self._enter_task(loop, task) + self.assertIs(asyncio.current_task(loop), task) + self._leave_task(loop, task) + + def test__enter_task_failure(self): + task1 = mock.Mock() + task2 = mock.Mock() + loop = mock.Mock() + self._enter_task(loop, task1) + with self.assertRaises(RuntimeError): + self._enter_task(loop, task2) + self.assertIs(asyncio.current_task(loop), task1) + self._leave_task(loop, task1) + + def test__leave_task(self): + task = mock.Mock() + loop = mock.Mock() + self._enter_task(loop, task) + self._leave_task(loop, task) + self.assertIsNone(asyncio.current_task(loop)) + + def test__leave_task_failure1(self): + task1 = mock.Mock() + task2 = mock.Mock() + loop = mock.Mock() + self._enter_task(loop, task1) + with self.assertRaises(RuntimeError): + self._leave_task(loop, task2) + 
self.assertIs(asyncio.current_task(loop), task1) + self._leave_task(loop, task1) + + def test__leave_task_failure2(self): + task = mock.Mock() + loop = mock.Mock() + with self.assertRaises(RuntimeError): + self._leave_task(loop, task) + self.assertIsNone(asyncio.current_task(loop)) + + def test__unregister_task(self): + task = mock.Mock() + loop = mock.Mock() + task.get_loop = lambda: loop + self._register_task(task) + self._unregister_task(task) + self.assertEqual(asyncio.all_tasks(loop), set()) + + def test__unregister_task_not_registered(self): + task = mock.Mock() + loop = mock.Mock() + self._unregister_task(task) + self.assertEqual(asyncio.all_tasks(loop), set()) + + +class PyIntrospectionTests(test_utils.TestCase, BaseTaskIntrospectionTests): + _register_task = staticmethod(tasks._py_register_task) + _unregister_task = staticmethod(tasks._py_unregister_task) + _enter_task = staticmethod(tasks._py_enter_task) + _leave_task = staticmethod(tasks._py_leave_task) + + +@unittest.skipUnless(hasattr(tasks, '_c_register_task'), + 'requires the C _asyncio module') +class CIntrospectionTests(test_utils.TestCase, BaseTaskIntrospectionTests): + if hasattr(tasks, '_c_register_task'): + _register_task = staticmethod(tasks._c_register_task) + _unregister_task = staticmethod(tasks._c_unregister_task) + _enter_task = staticmethod(tasks._c_enter_task) + _leave_task = staticmethod(tasks._c_leave_task) + else: + _register_task = _unregister_task = _enter_task = _leave_task = None + + +class BaseCurrentLoopTests: + current_task = None + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + def new_task(self, coro): + raise NotImplementedError + + def test_current_task_no_running_loop(self): + self.assertIsNone(self.current_task(loop=self.loop)) + + def test_current_task_no_running_loop_implicit(self): + with self.assertRaisesRegex(RuntimeError, 'no running event loop'): + self.current_task() + + def 
test_current_task_with_implicit_loop(self): + async def coro(): + self.assertIs(self.current_task(loop=self.loop), task) + + self.assertIs(self.current_task(None), task) + self.assertIs(self.current_task(), task) + + task = self.new_task(coro()) + self.loop.run_until_complete(task) + self.assertIsNone(self.current_task(loop=self.loop)) + + +class PyCurrentLoopTests(BaseCurrentLoopTests, test_utils.TestCase): + current_task = staticmethod(tasks._py_current_task) + + def new_task(self, coro): + return tasks._PyTask(coro, loop=self.loop) + + +@unittest.skipUnless(hasattr(tasks, '_CTask') and + hasattr(tasks, '_c_current_task'), + 'requires the C _asyncio module') +class CCurrentLoopTests(BaseCurrentLoopTests, test_utils.TestCase): + if hasattr(tasks, '_c_current_task'): + current_task = staticmethod(tasks._c_current_task) + else: + current_task = None + + def new_task(self, coro): + return getattr(tasks, '_CTask')(coro, loop=self.loop) + + +class GenericTaskTests(test_utils.TestCase): + + def test_future_subclass(self): + self.assertTrue(issubclass(asyncio.Task, asyncio.Future)) + + @support.cpython_only + def test_asyncio_module_compiled(self): + # Because of circular imports it's easy to make _asyncio + # module non-importable. This is a simple test that will + # fail on systems where C modules were successfully compiled + # (hence the test for _functools etc), but _asyncio somehow didn't. 
+ try: + import _functools + import _json + import _pickle + except ImportError: + self.skipTest('C modules are not available') + else: + try: + import _asyncio + except ImportError: + self.fail('_asyncio module is missing') + + +class GatherTestsBase: + + def setUp(self): + super().setUp() + self.one_loop = self.new_test_loop() + self.other_loop = self.new_test_loop() + self.set_event_loop(self.one_loop, cleanup=False) + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [self.one_loop.create_future() for i in range(3)] + fut = self._gather(*self.wrap_futures(a, b, c), **kwargs) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [self.one_loop.create_future() for i in range(5)] + fut = self._gather(*self.wrap_futures(a, b, c, d, e)) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_return_exceptions(self): + a, b, c, d = [self.one_loop.create_future() for i in range(4)] + fut = self._gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() 
+ b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + def test_env_var_debug(self): + code = '\n'.join(( + 'import asyncio.coroutines', + 'print(asyncio.coroutines._is_debug_mode())')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='', + PYTHONDEVMODE='') + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONDEVMODE='') + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONDEVMODE='') + self.assertEqual(stdout.rstrip(), b'False') + + # -X dev + sts, stdout, stderr = assert_python_ok('-E', '-X', 'dev', + '-c', code) + self.assertEqual(stdout.rstrip(), b'True') + + +class FutureGatherTests(GatherTestsBase, test_utils.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _gather(self, *args, **kwargs): + return asyncio.gather(*args, **kwargs) + + def test_constructor_empty_sequence_without_loop(self): + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + asyncio.gather() + + def test_constructor_empty_sequence_use_running_loop(self): + async def gather(): + return asyncio.gather() + fut = self.one_loop.run_until_complete(gather()) + self.assertIsInstance(fut, asyncio.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + + def test_constructor_empty_sequence_use_global_loop(self): + # 
Deprecated in 3.10, undeprecated in 3.12 + asyncio.set_event_loop(self.one_loop) + self.addCleanup(asyncio.set_event_loop, None) + fut = asyncio.gather() + self.assertIsInstance(fut, asyncio.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + + def test_constructor_heterogenous_futures(self): + fut1 = self.one_loop.create_future() + fut2 = self.other_loop.create_future() + with self.assertRaises(ValueError): + asyncio.gather(fut1, fut2) + + def test_constructor_homogenous_futures(self): + children = [self.other_loop.create_future() for i in range(3)] + fut = asyncio.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = asyncio.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [self.one_loop.create_future() for i in range(5)] + fut = asyncio.gather(a, b, c, d, e) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), asyncio.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [self.one_loop.create_future() + for i in range(6)] + fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + 
self.assertIsInstance(res[2], asyncio.CancelledError) + self.assertIsInstance(res[4], asyncio.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, test_utils.TestCase): + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + async def coro(fut=fut): + return await fut + coros.append(coro()) + return coros + + def _gather(self, *args, **kwargs): + async def coro(): + return asyncio.gather(*args, **kwargs) + return self.one_loop.run_until_complete(coro()) + + def test_constructor_without_loop(self): + async def coro(): + return 'abc' + gen1 = coro() + self.addCleanup(gen1.close) + gen2 = coro() + self.addCleanup(gen2.close) + with self.assertRaisesRegex(RuntimeError, 'no current event loop'): + asyncio.gather(gen1, gen2) + + def test_constructor_use_running_loop(self): + async def coro(): + return 'abc' + gen1 = coro() + gen2 = coro() + async def gather(): + return asyncio.gather(gen1, gen2) + fut = self.one_loop.run_until_complete(gather()) + self.assertIs(fut._loop, self.one_loop) + self.one_loop.run_until_complete(fut) + + def test_constructor_use_global_loop(self): + # Deprecated in 3.10, undeprecated in 3.12 + async def coro(): + return 'abc' + asyncio.set_event_loop(self.other_loop) + self.addCleanup(asyncio.set_event_loop, None) + gen1 = coro() + gen2 = coro() + fut = asyncio.gather(gen1, gen2) + self.assertIs(fut._loop, self.other_loop) + self.other_loop.run_until_complete(fut) + + def test_duplicate_coroutines(self): + async def coro(s): + return s + c = coro('abc') + fut = self._gather(c, c, coro('def'), c) + self._run_loop(self.one_loop) + self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc']) + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. 
+ proof = 0 + waiter = self.one_loop.create_future() + + async def inner(): + nonlocal proof + await waiter + proof += 1 + + child1 = asyncio.ensure_future(inner(), loop=self.one_loop) + child2 = asyncio.ensure_future(inner(), loop=self.one_loop) + gatherer = None + + async def outer(): + nonlocal proof, gatherer + gatherer = asyncio.gather(child1, child2) + await gatherer + proof += 100 + + f = asyncio.ensure_future(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(asyncio.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + async def inner(f): + await f + raise RuntimeError('should not be ignored') + + a = self.one_loop.create_future() + b = self.one_loop.create_future() + + async def outer(): + await asyncio.gather(inner(a), inner(b)) + + f = asyncio.ensure_future(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + def test_issue46672(self): + with mock.patch( + 'asyncio.base_events.BaseEventLoop.call_exception_handler', + ): + async def coro(s): + return s + c = coro('abc') + + with self.assertRaises(TypeError): + self._gather(c, {}) + self._run_loop(self.one_loop) + # NameError should not happen: + self.one_loop.call_exception_handler.assert_not_called() + + +class RunCoroutineThreadsafeTests(test_utils.TestCase): + """Test case for asyncio.run_coroutine_threadsafe.""" + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + 
self.set_event_loop(self.loop) # Will cleanup properly + + async def add(self, a, b, fail=False, cancel=False): + """Wait 0.05 second and return a + b.""" + await asyncio.sleep(0.05) + if fail: + raise RuntimeError("Fail!") + if cancel: + asyncio.current_task(self.loop).cancel() + await asyncio.sleep(0) + return a + b + + def target(self, fail=False, cancel=False, timeout=None, + advance_coro=False): + """Run add coroutine in the event loop.""" + coro = self.add(1, 2, fail=fail, cancel=cancel) + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + if advance_coro: + # this is for test_run_coroutine_threadsafe_task_factory_exception; + # otherwise it spills errors and breaks **other** unittests, since + # 'target' is interacting with threads. + + # With this call, `coro` will be advanced. + self.loop.call_soon_threadsafe(coro.send, None) + try: + return future.result(timeout) + finally: + future.done() or future.cancel() + + def test_run_coroutine_threadsafe(self): + """Test coroutine submission from a thread to an event loop.""" + future = self.loop.run_in_executor(None, self.target) + result = self.loop.run_until_complete(future) + self.assertEqual(result, 3) + + def test_run_coroutine_threadsafe_with_exception(self): + """Test coroutine submission from a thread to an event loop + when an exception is raised.""" + future = self.loop.run_in_executor(None, self.target, True) + with self.assertRaises(RuntimeError) as exc_context: + self.loop.run_until_complete(future) + self.assertIn("Fail!", exc_context.exception.args) + + def test_run_coroutine_threadsafe_with_timeout(self): + """Test coroutine submission from a thread to an event loop + when a timeout is raised.""" + callback = lambda: self.target(timeout=0) + future = self.loop.run_in_executor(None, callback) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(future) + test_utils.run_briefly(self.loop) + # Check that there's no pending task (add has been cancelled) + for task 
in asyncio.all_tasks(self.loop): + self.assertTrue(task.done()) + + def test_run_coroutine_threadsafe_task_cancelled(self): + """Test coroutine submission from a thread to an event loop + when the task is cancelled.""" + callback = lambda: self.target(cancel=True) + future = self.loop.run_in_executor(None, callback) + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(future) + + def test_run_coroutine_threadsafe_task_factory_exception(self): + """Test coroutine submission from a thread to an event loop + when the task factory raise an exception.""" + + def task_factory(loop, coro): + raise NameError + + run = self.loop.run_in_executor( + None, lambda: self.target(advance_coro=True)) + + # Set exception handler + callback = test_utils.MockCallback() + self.loop.set_exception_handler(callback) + + # Set corrupted task factory + self.addCleanup(self.loop.set_task_factory, + self.loop.get_task_factory()) + self.loop.set_task_factory(task_factory) + + # Run event loop + with self.assertRaises(NameError) as exc_context: + self.loop.run_until_complete(run) + + # Check exceptions + self.assertEqual(len(callback.call_args_list), 1) + (loop, context), kwargs = callback.call_args + self.assertEqual(context['exception'], exc_context.exception) + + +class SleepTests(test_utils.TestCase): + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + self.loop = None + super().tearDown() + + def test_sleep_zero(self): + result = 0 + + def inc_result(num): + nonlocal result + result += num + + async def coro(): + self.loop.call_soon(inc_result, 1) + self.assertEqual(result, 0) + num = await asyncio.sleep(0, result=10) + self.assertEqual(result, 1) # inc'ed by call_soon + inc_result(num) # num should be 11 + + self.loop.run_until_complete(coro()) + self.assertEqual(result, 11) + + +class CompatibilityTests(test_utils.TestCase): + # Tests for checking a bridge 
between old-styled coroutines + # and async/await syntax + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + self.loop = None + super().tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_threads.py b/Lib/test/test_asyncio/test_threads.py new file mode 100644 index 00000000000..774380270a7 --- /dev/null +++ b/Lib/test/test_asyncio/test_threads.py @@ -0,0 +1,66 @@ +"""Tests for asyncio/threads.py""" + +import asyncio +import unittest + +from contextvars import ContextVar +from unittest import mock + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class ToThreadTests(unittest.IsolatedAsyncioTestCase): + async def test_to_thread(self): + result = await asyncio.to_thread(sum, [40, 2]) + self.assertEqual(result, 42) + + async def test_to_thread_exception(self): + def raise_runtime(): + raise RuntimeError("test") + + with self.assertRaisesRegex(RuntimeError, "test"): + await asyncio.to_thread(raise_runtime) + + async def test_to_thread_once(self): + func = mock.Mock() + + await asyncio.to_thread(func) + func.assert_called_once() + + async def test_to_thread_concurrent(self): + calls = [] + def func(): + calls.append(1) + + futs = [] + for _ in range(10): + fut = asyncio.to_thread(func) + futs.append(fut) + await asyncio.gather(*futs) + + self.assertEqual(sum(calls), 10) + + async def test_to_thread_args_kwargs(self): + # Unlike run_in_executor(), to_thread() should directly accept kwargs. 
+ func = mock.Mock() + + await asyncio.to_thread(func, 'test', something=True) + + func.assert_called_once_with('test', something=True) + + async def test_to_thread_contextvars(self): + test_ctx = ContextVar('test_ctx') + + def get_ctx(): + return test_ctx.get() + + test_ctx.set('parrot') + result = await asyncio.to_thread(get_ctx) + + self.assertEqual(result, 'parrot') + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_asyncio/test_timeouts.py b/Lib/test/test_asyncio/test_timeouts.py new file mode 100644 index 00000000000..caab9e3917a --- /dev/null +++ b/Lib/test/test_asyncio/test_timeouts.py @@ -0,0 +1,420 @@ +"""Tests for asyncio/timeouts.py""" + +import unittest +import time + +import asyncio + +from test.test_asyncio.utils import await_without_task + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + +class TimeoutTests(unittest.IsolatedAsyncioTestCase): + + async def test_timeout_basic(self): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.01) as cm: + await asyncio.sleep(10) + self.assertTrue(cm.expired()) + + async def test_timeout_at_basic(self): + loop = asyncio.get_running_loop() + + with self.assertRaises(TimeoutError): + deadline = loop.time() + 0.01 + async with asyncio.timeout_at(deadline) as cm: + await asyncio.sleep(10) + self.assertTrue(cm.expired()) + self.assertEqual(deadline, cm.when()) + + async def test_nested_timeouts(self): + loop = asyncio.get_running_loop() + cancelled = False + with self.assertRaises(TimeoutError): + deadline = loop.time() + 0.01 + async with asyncio.timeout_at(deadline) as cm1: + # Only the topmost context manager should raise TimeoutError + try: + async with asyncio.timeout_at(deadline) as cm2: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancelled = True + raise + self.assertTrue(cancelled) + self.assertTrue(cm1.expired()) + self.assertTrue(cm2.expired()) + + async def test_waiter_cancelled(self): + cancelled = False + with 
self.assertRaises(TimeoutError): + async with asyncio.timeout(0.01): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancelled = True + raise + self.assertTrue(cancelled) + + async def test_timeout_not_called(self): + loop = asyncio.get_running_loop() + async with asyncio.timeout(10) as cm: + await asyncio.sleep(0.01) + t1 = loop.time() + + self.assertFalse(cm.expired()) + self.assertGreater(cm.when(), t1) + + async def test_timeout_disabled(self): + async with asyncio.timeout(None) as cm: + await asyncio.sleep(0.01) + + self.assertFalse(cm.expired()) + self.assertIsNone(cm.when()) + + async def test_timeout_at_disabled(self): + async with asyncio.timeout_at(None) as cm: + await asyncio.sleep(0.01) + + self.assertFalse(cm.expired()) + self.assertIsNone(cm.when()) + + async def test_timeout_zero(self): + loop = asyncio.get_running_loop() + t0 = loop.time() + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0) as cm: + await asyncio.sleep(10) + t1 = loop.time() + self.assertTrue(cm.expired()) + self.assertTrue(t0 <= cm.when() <= t1) + + async def test_timeout_zero_sleep_zero(self): + loop = asyncio.get_running_loop() + t0 = loop.time() + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0) as cm: + await asyncio.sleep(0) + t1 = loop.time() + self.assertTrue(cm.expired()) + self.assertTrue(t0 <= cm.when() <= t1) + + async def test_timeout_in_the_past_sleep_zero(self): + loop = asyncio.get_running_loop() + t0 = loop.time() + with self.assertRaises(TimeoutError): + async with asyncio.timeout(-11) as cm: + await asyncio.sleep(0) + t1 = loop.time() + self.assertTrue(cm.expired()) + self.assertTrue(t0 >= cm.when() <= t1) + + async def test_foreign_exception_passed(self): + with self.assertRaises(KeyError): + async with asyncio.timeout(0.01) as cm: + raise KeyError + self.assertFalse(cm.expired()) + + async def test_timeout_exception_context(self): + with self.assertRaises(TimeoutError) as cm: + async with 
asyncio.timeout(0.01): + try: + 1/0 + finally: + await asyncio.sleep(1) + e = cm.exception + # Expect TimeoutError caused by CancelledError raised during handling + # of ZeroDivisionError. + e2 = e.__cause__ + self.assertIsInstance(e2, asyncio.CancelledError) + self.assertIs(e.__context__, e2) + self.assertIsNone(e2.__cause__) + self.assertIsInstance(e2.__context__, ZeroDivisionError) + + # TODO: RUSTPYTHON + # AssertionError: CancelledError() is not an instance of + @unittest.expectedFailure + async def test_foreign_exception_on_timeout(self): + async def crash(): + try: + await asyncio.sleep(1) + finally: + 1/0 + with self.assertRaises(ZeroDivisionError) as cm: + async with asyncio.timeout(0.01): + await crash() + e = cm.exception + # Expect ZeroDivisionError raised during handling of TimeoutError + # caused by CancelledError. + self.assertIsNone(e.__cause__) + e2 = e.__context__ + self.assertIsInstance(e2, TimeoutError) + e3 = e2.__cause__ + self.assertIsInstance(e3, asyncio.CancelledError) + self.assertIs(e2.__context__, e3) + + # TODO: RUSTPYTHON + # AssertionError: CancelledError() is not an instance of + @unittest.expectedFailure + async def test_foreign_exception_on_timeout_2(self): + with self.assertRaises(ZeroDivisionError) as cm: + async with asyncio.timeout(0.01): + try: + try: + raise ValueError + finally: + await asyncio.sleep(1) + finally: + try: + raise KeyError + finally: + 1/0 + e = cm.exception + # Expect ZeroDivisionError raised during handling of KeyError + # raised during handling of TimeoutError caused by CancelledError. 
+ self.assertIsNone(e.__cause__) + e2 = e.__context__ + self.assertIsInstance(e2, KeyError) + self.assertIsNone(e2.__cause__) + e3 = e2.__context__ + self.assertIsInstance(e3, TimeoutError) + e4 = e3.__cause__ + self.assertIsInstance(e4, asyncio.CancelledError) + self.assertIsNone(e4.__cause__) + self.assertIsInstance(e4.__context__, ValueError) + self.assertIs(e3.__context__, e4) + + async def test_foreign_cancel_doesnt_timeout_if_not_expired(self): + with self.assertRaises(asyncio.CancelledError): + async with asyncio.timeout(10) as cm: + asyncio.current_task().cancel() + await asyncio.sleep(10) + self.assertFalse(cm.expired()) + + async def test_outer_task_is_not_cancelled(self): + async def outer() -> None: + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.001): + await asyncio.sleep(10) + + task = asyncio.create_task(outer()) + await task + self.assertFalse(task.cancelled()) + self.assertTrue(task.done()) + + async def test_nested_timeouts_concurrent(self): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.002): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.1): + # Pretend we crunch some numbers. + time.sleep(0.01) + await asyncio.sleep(1) + + async def test_nested_timeouts_loop_busy(self): + # After the inner timeout is an expensive operation which should + # be stopped by the outer timeout. + loop = asyncio.get_running_loop() + # Disable a message about long running task + loop.slow_callback_duration = 10 + t0 = loop.time() + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.1): # (1) + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.01): # (2) + # Pretend the loop is busy for a while. 
+ time.sleep(0.1) + await asyncio.sleep(1) + # TimeoutError was cought by (2) + await asyncio.sleep(10) # This sleep should be interrupted by (1) + t1 = loop.time() + self.assertTrue(t0 <= t1 <= t0 + 1) + + async def test_reschedule(self): + loop = asyncio.get_running_loop() + fut = loop.create_future() + deadline1 = loop.time() + 10 + deadline2 = deadline1 + 20 + + async def f(): + async with asyncio.timeout_at(deadline1) as cm: + fut.set_result(cm) + await asyncio.sleep(50) + + task = asyncio.create_task(f()) + cm = await fut + + self.assertEqual(cm.when(), deadline1) + cm.reschedule(deadline2) + self.assertEqual(cm.when(), deadline2) + cm.reschedule(None) + self.assertIsNone(cm.when()) + + task.cancel() + + with self.assertRaises(asyncio.CancelledError): + await task + self.assertFalse(cm.expired()) + + async def test_repr_active(self): + async with asyncio.timeout(10) as cm: + self.assertRegex(repr(cm), r"") + + async def test_repr_expired(self): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.01) as cm: + await asyncio.sleep(10) + self.assertEqual(repr(cm), "") + + async def test_repr_finished(self): + async with asyncio.timeout(10) as cm: + await asyncio.sleep(0) + + self.assertEqual(repr(cm), "") + + async def test_repr_disabled(self): + async with asyncio.timeout(None) as cm: + self.assertEqual(repr(cm), r"") + + async def test_nested_timeout_in_finally(self): + with self.assertRaises(TimeoutError) as cm1: + async with asyncio.timeout(0.01): + try: + await asyncio.sleep(1) + finally: + with self.assertRaises(TimeoutError) as cm2: + async with asyncio.timeout(0.01): + await asyncio.sleep(10) + e1 = cm1.exception + # Expect TimeoutError caused by CancelledError. 
+ e12 = e1.__cause__ + self.assertIsInstance(e12, asyncio.CancelledError) + self.assertIsNone(e12.__cause__) + self.assertIsNone(e12.__context__) + self.assertIs(e1.__context__, e12) + e2 = cm2.exception + # Expect TimeoutError caused by CancelledError raised during + # handling of other CancelledError (which is the same as in + # the above chain). + e22 = e2.__cause__ + self.assertIsInstance(e22, asyncio.CancelledError) + self.assertIsNone(e22.__cause__) + self.assertIs(e22.__context__, e12) + self.assertIs(e2.__context__, e22) + + async def test_timeout_after_cancellation(self): + try: + asyncio.current_task().cancel() + await asyncio.sleep(1) # work which will be cancelled + except asyncio.CancelledError: + pass + finally: + with self.assertRaises(TimeoutError) as cm: + async with asyncio.timeout(0.0): + await asyncio.sleep(1) # some cleanup + + async def test_cancel_in_timeout_after_cancellation(self): + try: + asyncio.current_task().cancel() + await asyncio.sleep(1) # work which will be cancelled + except asyncio.CancelledError: + pass + finally: + with self.assertRaises(asyncio.CancelledError): + async with asyncio.timeout(1.0): + asyncio.current_task().cancel() + await asyncio.sleep(2) # some cleanup + + async def test_timeout_already_entered(self): + async with asyncio.timeout(0.01) as cm: + with self.assertRaisesRegex(RuntimeError, "has already been entered"): + async with cm: + pass + + async def test_timeout_double_enter(self): + async with asyncio.timeout(0.01) as cm: + pass + with self.assertRaisesRegex(RuntimeError, "has already been entered"): + async with cm: + pass + + async def test_timeout_finished(self): + async with asyncio.timeout(0.01) as cm: + pass + with self.assertRaisesRegex(RuntimeError, "finished"): + cm.reschedule(0.02) + + async def test_timeout_expired(self): + with self.assertRaises(TimeoutError): + async with asyncio.timeout(0.01) as cm: + await asyncio.sleep(1) + with self.assertRaisesRegex(RuntimeError, "expired"): + 
cm.reschedule(0.02) + + async def test_timeout_expiring(self): + async with asyncio.timeout(0.01) as cm: + with self.assertRaises(asyncio.CancelledError): + await asyncio.sleep(1) + with self.assertRaisesRegex(RuntimeError, "expiring"): + cm.reschedule(0.02) + + async def test_timeout_not_entered(self): + cm = asyncio.timeout(0.01) + with self.assertRaisesRegex(RuntimeError, "has not been entered"): + cm.reschedule(0.02) + + async def test_timeout_without_task(self): + cm = asyncio.timeout(0.01) + with self.assertRaisesRegex(RuntimeError, "task"): + await await_without_task(cm.__aenter__()) + with self.assertRaisesRegex(RuntimeError, "has not been entered"): + cm.reschedule(0.02) + + # TODO: RUSTPYTHON + # AssertionError: CancelledError() is not an instance of + @unittest.expectedFailure + async def test_timeout_taskgroup(self): + async def task(): + try: + await asyncio.sleep(2) # Will be interrupted after 0.01 second + finally: + 1/0 # Crash in cleanup + + with self.assertRaises(ExceptionGroup) as cm: + async with asyncio.timeout(0.01): + async with asyncio.TaskGroup() as tg: + tg.create_task(task()) + try: + raise ValueError + finally: + await asyncio.sleep(1) + eg = cm.exception + # Expect ExceptionGroup raised during handling of TimeoutError caused + # by CancelledError raised during handling of ValueError. + self.assertIsNone(eg.__cause__) + e_1 = eg.__context__ + self.assertIsInstance(e_1, TimeoutError) + e_2 = e_1.__cause__ + self.assertIsInstance(e_2, asyncio.CancelledError) + self.assertIsNone(e_2.__cause__) + self.assertIsInstance(e_2.__context__, ValueError) + self.assertIs(e_1.__context__, e_2) + + self.assertEqual(len(eg.exceptions), 1, eg) + e1 = eg.exceptions[0] + # Expect ZeroDivisionError raised during handling of TimeoutError + # caused by CancelledError (it is a different CancelledError). 
+ self.assertIsInstance(e1, ZeroDivisionError) + self.assertIsNone(e1.__cause__) + e2 = e1.__context__ + self.assertIsInstance(e2, TimeoutError) + e3 = e2.__cause__ + self.assertIsInstance(e3, asyncio.CancelledError) + self.assertIsNone(e3.__context__) + self.assertIsNone(e3.__cause__) + self.assertIs(e2.__context__, e3) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py new file mode 100644 index 00000000000..bbdb218efaa --- /dev/null +++ b/Lib/test/test_asyncio/test_transports.py @@ -0,0 +1,103 @@ +"""Tests for transports.py.""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import transports + + +def tearDownModule(): + # not needed for the test file but added for uniformness with all other + # asyncio test files for the sake of unified cleanup + asyncio.set_event_loop_policy(None) + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = asyncio.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = asyncio.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + writer = mock.Mock() + + class MyTransport(asyncio.Transport): + def write(self, data): + writer(data) + + transport = MyTransport() + + transport.writelines([b'line1', + bytearray(b'line2'), + memoryview(b'line3')]) + self.assertEqual(1, writer.call_count) + writer.assert_called_with(b'line1line2line3') + + def test_not_implemented(self): + transport = asyncio.Transport() + + self.assertRaises(NotImplementedError, + transport.set_write_buffer_limits) + self.assertRaises(NotImplementedError, transport.get_write_buffer_size) + self.assertRaises(NotImplementedError, 
transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause_reading) + self.assertRaises(NotImplementedError, transport.resume_reading) + self.assertRaises(NotImplementedError, transport.is_reading) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = asyncio.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = asyncio.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) + + def test_flowcontrol_mixin_set_write_limits(self): + + class MyTransport(transports._FlowControlMixin, + transports.Transport): + + def get_write_buffer_size(self): + return 512 + + loop = mock.Mock() + transport = MyTransport(loop=loop) + transport._protocol = mock.Mock() + + self.assertFalse(transport._protocol_paused) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + transport.set_write_buffer_limits(high=0, low=1) + + transport.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 1024)) + + transport.set_write_buffer_limits(high=256, low=128) + self.assertTrue(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 256)) + + +if __name__ == '__main__': + 
unittest.main() diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py new file mode 100644 index 00000000000..9ed74cc1da0 --- /dev/null +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -0,0 +1,2002 @@ +"""Tests for unix_events.py.""" + +import contextlib +import errno +import io +import multiprocessing +from multiprocessing.util import _cleanup_tests as multiprocessing_cleanup_tests +import os +import signal +import socket +import stat +import sys +import threading +import time +import unittest +from unittest import mock +import warnings + +from test import support +from test.support import os_helper +from test.support import socket_helper +from test.support import wait_process +from test.support import hashlib_helper + +if sys.platform == 'win32': + raise unittest.SkipTest('UNIX only') + + +import asyncio +from asyncio import log +from asyncio import unix_events +from test.test_asyncio import utils as test_utils + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +MOCK_ANY = mock.ANY + + +def EXITCODE(exitcode): + return 32768 + exitcode + + +def SIGNAL(signum): + if not 1 <= signum <= 68: + raise AssertionError(f'invalid signum {signum}') + return 32768 - signum + + +def close_pipe_transport(transport): + # Don't call transport.close() because the event loop and the selector + # are mocked + if transport._pipe is None: + return + transport._pipe.close() + transport._pipe = None + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopSignalTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.SelectorEventLoop() + self.set_event_loop(self.loop) + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1) + + def 
test_handle_signal_cancelled_handler(self): + h = asyncio.Handle(mock.Mock(), (), + loop=mock.Mock()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = mock.Mock() + self.loop._handle_signal(signal.NSIG + 1) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_coroutine_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + async def simple_coroutine(): + pass + + # callback must not be a coroutine function + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + self.assertRaisesRegex( + TypeError, 'coroutines cannot be used with add_signal_handler', + self.loop.add_signal_handler, + signal.SIGINT, func) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertIsInstance(h, asyncio.Handle) + self.assertEqual(h._callback, cb) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: 
True) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + m_signal.valid_signals = signal.valid_signals + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + 
self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_close(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGCHLD, lambda: True) + + self.assertEqual(len(self.loop._signal_handlers), 2) + + m_signal.set_wakeup_fd.reset_mock() + + self.loop.close() + + self.assertEqual(len(self.loop._signal_handlers), 0) + m_signal.set_wakeup_fd.assert_called_once_with(-1) + + 
@mock.patch('asyncio.unix_events.sys') + @mock.patch('asyncio.unix_events.signal') + def test_close_on_finalizing(self, m_signal, m_sys): + m_signal.NSIG = signal.NSIG + m_signal.valid_signals = signal.valid_signals + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertEqual(len(self.loop._signal_handlers), 1) + m_sys.is_finalizing.return_value = True + m_signal.signal.reset_mock() + + with self.assertWarnsRegex(ResourceWarning, + "skipping signal handlers removal"): + self.loop.close() + + self.assertEqual(len(self.loop._signal_handlers), 0) + self.assertFalse(m_signal.signal.called) + + +@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') +class SelectorEventLoopUnixSocketTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.SelectorEventLoop() + self.set_event_loop(self.loop) + + @socket_helper.skip_unless_bind_unix_socket + def test_create_unix_server_existing_path_sock(self): + with test_utils.unix_socket_path() as path: + sock = socket.socket(socket.AF_UNIX) + sock.bind(path) + sock.listen(1) + sock.close() + + coro = self.loop.create_unix_server(lambda: None, path) + srv = self.loop.run_until_complete(coro) + srv.close() + self.loop.run_until_complete(srv.wait_closed()) + + @socket_helper.skip_unless_bind_unix_socket + def test_create_unix_server_pathlike(self): + with test_utils.unix_socket_path() as path: + path = os_helper.FakePath(path) + srv_coro = self.loop.create_unix_server(lambda: None, path) + srv = self.loop.run_until_complete(srv_coro) + srv.close() + self.loop.run_until_complete(srv.wait_closed()) + + def test_create_unix_connection_pathlike(self): + with test_utils.unix_socket_path() as path: + path = os_helper.FakePath(path) + coro = self.loop.create_unix_connection(lambda: None, path) + with self.assertRaises(FileNotFoundError): + # If path-like object weren't supported, the exception would be + # different. 
+ self.loop.run_until_complete(coro) + + def test_create_unix_server_existing_path_nonsock(self): + path = test_utils.gen_unix_socket_path() + self.addCleanup(os_helper.unlink, path) + # create the file + open(path, "wb").close() + + coro = self.loop.create_unix_server(lambda: None, path) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_ssl_bool(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl=True) + with self.assertRaisesRegex(TypeError, + 'ssl argument must be an SSLContext'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_nopath_nosock(self): + coro = self.loop.create_unix_server(lambda: None, path=None) + with self.assertRaisesRegex(ValueError, + 'path was not specified, and no sock'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_path_inetsock(self): + sock = socket.socket() + with sock: + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Stream.*was expected'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_path_dgram(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + with sock: + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Stream.*was expected'): + self.loop.run_until_complete(coro) + + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), + 'no socket.SOCK_NONBLOCK (linux only)') + @socket_helper.skip_unless_bind_unix_socket + @unittest.skip('TODO: RUSTPYTHON') + # ValueError: A UNIX Domain Stream Socket was expected, got + def test_create_unix_server_path_stream_bittype(self): + fn = test_utils.gen_unix_socket_path() + self.addCleanup(os_helper.unlink, fn) + + sock = socket.socket(socket.AF_UNIX, + socket.SOCK_STREAM | socket.SOCK_NONBLOCK) + with sock: + sock.bind(fn) + coro = 
self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + srv = self.loop.run_until_complete(coro) + srv.close() + self.loop.run_until_complete(srv.wait_closed()) + + def test_create_unix_server_ssl_timeout_with_plain_sock(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_path_inetsock(self): + sock = socket.socket() + with sock: + coro = self.loop.create_unix_connection(lambda: None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Stream.*was expected'): + self.loop.run_until_complete(coro) + + @mock.patch('asyncio.unix_events.socket') + def test_create_unix_server_bind_error(self, m_socket): + # Ensure that the socket is closed on any bind error + sock = mock.Mock() + m_socket.socket.return_value = sock + + sock.bind.side_effect = OSError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + sock.bind.side_effect = MemoryError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(MemoryError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + def test_create_unix_connection_path_sock(self): + coro = self.loop.create_unix_connection( + lambda: None, os.devnull, sock=object()) + with self.assertRaisesRegex(ValueError, 'path and sock can not be'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nopath_nosock(self): + coro = self.loop.create_unix_connection( + lambda: None, None) + with self.assertRaisesRegex(ValueError, + 'no path and sock were specified'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nossl_serverhost(self): + coro = self.loop.create_unix_connection( + 
lambda: None, os.devnull, server_hostname='spam') + with self.assertRaisesRegex(ValueError, + 'server_hostname is only meaningful'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_ssl_noserverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, os.devnull, ssl=True) + + with self.assertRaisesRegex( + ValueError, 'you have to pass server_hostname when using ssl'): + + self.loop.run_until_complete(coro) + + def test_create_unix_connection_ssl_timeout_with_plain_sock(self): + coro = self.loop.create_unix_connection(lambda: None, path='spam', + ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + + +@unittest.skipUnless(hasattr(os, 'sendfile'), + 'sendfile is not supported') +class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): + DATA = b"12345abcde" * 16 * 1024 # 160 KiB + + class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + self._ready = loop.create_future() + + def connection_made(self, transport): + self.started = True + self.transport = transport + self._ready.set_result(None) + + def data_received(self, data): + self.data.extend(data) + + def connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + + async def wait_closed(self): + await self.fut + + @classmethod + def setUpClass(cls): + with open(os_helper.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + @classmethod + def tearDownClass(cls): + os_helper.unlink(os_helper.TESTFN) + super().tearDownClass() + + def setUp(self): + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + self.file = open(os_helper.TESTFN, 'rb') + self.addCleanup(self.file.close) + super().setUp() + + def make_socket(self, cleanup=True): + sock = 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + if cleanup: + self.addCleanup(sock.close) + return sock + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + def prepare(self): + sock = self.make_socket() + proto = self.MyProto(self.loop) + port = socket_helper.find_unused_port() + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind((socket_helper.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: proto, sock=srv_sock)) + self.run_loop(self.loop.sock_connect(sock, (socket_helper.HOST, port))) + self.run_loop(proto._ready) + + def cleanup(): + proto.transport.close() + self.run_loop(proto.wait_closed()) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return sock, proto + + def test_sock_sendfile_not_available(self): + sock, proto = self.prepare() + with mock.patch('asyncio.unix_events.os', spec=[]): + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "os[.]sendfile[(][)] is not available"): + self.run_loop(self.loop._sock_sendfile_native(sock, self.file, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_not_a_file(self): + sock, proto = self.prepare() + f = object() + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_iobuffer(self): + sock, proto = self.prepare() + f = io.BytesIO() + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_not_regular_file(self): + sock, proto = self.prepare() + f = mock.Mock() + f.fileno.return_value = -1 + with 
self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_cancel1(self): + sock, proto = self.prepare() + + fut = self.loop.create_future() + fileno = self.file.fileno() + self.loop._sock_sendfile_native_impl(fut, None, sock, fileno, + 0, None, len(self.DATA), 0) + fut.cancel() + with contextlib.suppress(asyncio.CancelledError): + self.run_loop(fut) + with self.assertRaises(KeyError): + self.loop._selector.get_key(sock) + + def test_sock_sendfile_cancel2(self): + sock, proto = self.prepare() + + fut = self.loop.create_future() + fileno = self.file.fileno() + self.loop._sock_sendfile_native_impl(fut, None, sock, fileno, + 0, None, len(self.DATA), 0) + fut.cancel() + self.loop._sock_sendfile_native_impl(fut, sock.fileno(), sock, fileno, + 0, None, len(self.DATA), 0) + with self.assertRaises(KeyError): + self.loop._selector.get_key(sock) + + def test_sock_sendfile_blocking_error(self): + sock, proto = self.prepare() + + fileno = self.file.fileno() + fut = mock.Mock() + fut.cancelled.return_value = False + with mock.patch('os.sendfile', side_effect=BlockingIOError()): + self.loop._sock_sendfile_native_impl(fut, None, sock, fileno, + 0, None, len(self.DATA), 0) + key = self.loop._selector.get_key(sock) + self.assertIsNotNone(key) + fut.add_done_callback.assert_called_once_with(mock.ANY) + + def test_sock_sendfile_os_error_first_call(self): + sock, proto = self.prepare() + + fileno = self.file.fileno() + fut = self.loop.create_future() + with mock.patch('os.sendfile', side_effect=OSError()): + self.loop._sock_sendfile_native_impl(fut, None, sock, fileno, + 0, None, len(self.DATA), 0) + with self.assertRaises(KeyError): + self.loop._selector.get_key(sock) + exc = fut.exception() + self.assertIsInstance(exc, asyncio.SendfileNotAvailableError) + self.assertEqual(0, self.file.tell()) + + def 
test_sock_sendfile_os_error_next_call(self): + sock, proto = self.prepare() + + fileno = self.file.fileno() + fut = self.loop.create_future() + err = OSError() + with mock.patch('os.sendfile', side_effect=err): + self.loop._sock_sendfile_native_impl(fut, sock.fileno(), + sock, fileno, + 1000, None, len(self.DATA), + 1000) + with self.assertRaises(KeyError): + self.loop._selector.get_key(sock) + exc = fut.exception() + self.assertIs(exc, err) + self.assertEqual(1000, self.file.tell()) + + def test_sock_sendfile_exception(self): + sock, proto = self.prepare() + + fileno = self.file.fileno() + fut = self.loop.create_future() + err = asyncio.SendfileNotAvailableError() + with mock.patch('os.sendfile', side_effect=err): + self.loop._sock_sendfile_native_impl(fut, sock.fileno(), + sock, fileno, + 1000, None, len(self.DATA), + 1000) + with self.assertRaises(KeyError): + self.loop._selector.get_key(sock) + exc = fut.exception() + self.assertIs(exc, err) + self.assertEqual(1000, self.file.tell()) + + +class UnixReadPipeTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + blocking_patcher = mock.patch('os.set_blocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def read_pipe_transport(self, waiter=None): + transport = unix_events._UnixReadPipeTransport(self.loop, self.pipe, + self.protocol, + waiter=waiter) + self.addCleanup(close_pipe_transport, transport) + return transport + + def test_ctor(self): + waiter = self.loop.create_future() + tr = self.read_pipe_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + + 
self.protocol.connection_made.assert_called_with(tr) + self.loop.assert_reader(5, tr._read_ready) + self.assertIsNone(waiter.result()) + + @mock.patch('os.read') + def test__read_ready(self, m_read): + tr = self.read_pipe_transport() + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = self.read_pipe_transport() + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = self.read_pipe_transport() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = self.read_pipe_transport() + err = OSError() + m_read.side_effect = err + tr._close = mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal read error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + + @mock.patch('os.read') + def test_pause_reading(self, m_read): + tr = self.read_pipe_transport() + m = mock.Mock() + self.loop.add_reader(5, m) + tr.pause_reading() + self.assertFalse(self.loop.readers) + + @mock.patch('os.read') + def test_resume_reading(self, m_read): + tr = self.read_pipe_transport() + tr.pause_reading() + tr.resume_reading() + self.loop.assert_reader(5, tr._read_ready) + + @mock.patch('os.read') + def 
test_close(self, m_read): + tr = self.read_pipe_transport() + tr._close = mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = self.read_pipe_transport() + tr._closing = True + tr._close = mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @mock.patch('os.read') + def test__close(self, m_read): + tr = self.read_pipe_transport() + err = object() + tr._close(err) + self.assertTrue(tr.is_closing()) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = self.read_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test__call_connection_lost_with_err(self): + tr = self.read_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test_pause_reading_on_closed_pipe(self): + tr = self.read_pipe_transport() + tr.close() + test_utils.run_briefly(self.loop) + self.assertIsNone(tr._loop) + tr.pause_reading() + + def test_pause_reading_on_paused_pipe(self): + tr = self.read_pipe_transport() + tr.pause_reading() + # the second call should do nothing + tr.pause_reading() + + def test_resume_reading_on_closed_pipe(self): + tr = self.read_pipe_transport() + tr.close() + test_utils.run_briefly(self.loop) + self.assertIsNone(tr._loop) + tr.resume_reading() + + def test_resume_reading_on_paused_pipe(self): + tr = self.read_pipe_transport() + # the pipe is not paused 
+ # resuming should do nothing + tr.resume_reading() + + +class UnixWritePipeTransportTests(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + blocking_patcher = mock.patch('os.set_blocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFSOCK + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def write_pipe_transport(self, waiter=None): + transport = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol, + waiter=waiter) + self.addCleanup(close_pipe_transport, transport) + return transport + + def test_ctor(self): + waiter = self.loop.create_future() + tr = self.write_pipe_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + + self.protocol.connection_made.assert_called_with(tr) + self.loop.assert_reader(5, tr._read_ready) + self.assertEqual(None, waiter.result()) + + def test_can_write_eof(self): + tr = self.write_pipe_transport() + self.assertTrue(tr.can_write_eof()) + + @mock.patch('os.write') + def test_write(self, m_write): + tr = self.write_pipe_transport() + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual(bytearray(), tr._buffer) + + @mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = self.write_pipe_transport() + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual(bytearray(b''), tr._buffer) + + @mock.patch('os.write') + def test_write_partial(self, m_write): + tr = self.write_pipe_transport() + m_write.return_value = 2 + tr.write(b'data') + self.loop.assert_writer(5, tr._write_ready) + 
self.assertEqual(bytearray(b'ta'), tr._buffer) + + @mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = bytearray(b'previous') + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual(bytearray(b'previousdata'), tr._buffer) + + @mock.patch('os.write') + def test_write_again(self, m_write): + tr = self.write_pipe_transport() + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, bytearray(b'data')) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual(bytearray(b'data'), tr._buffer) + + @mock.patch('asyncio.unix_events.logger') + @mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = self.write_pipe_transport() + err = OSError() + m_write.side_effect = err + tr._fatal_error = mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual(bytearray(), tr._buffer) + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + # This is a bit overspecified. 
:-( + m_log.warning.assert_called_with( + 'pipe closed by peer or os.write(pipe, data) raised exception.') + tr.close() + + @mock.patch('os.write') + def test_write_close(self, m_write): + tr = self.write_pipe_transport() + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + + def test__read_ready(self): + tr = self.write_pipe_transport() + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr.is_closing()) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.write') + def test__write_ready(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = bytearray(b'data') + m_write.return_value = 4 + tr._write_ready() + self.assertFalse(self.loop.writers) + self.assertEqual(bytearray(), tr._buffer) + + @mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = bytearray(b'data') + m_write.return_value = 3 + tr._write_ready() + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual(bytearray(b'a'), tr._buffer) + + @mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = bytearray(b'data') + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, bytearray(b'data')) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual(bytearray(b'data'), tr._buffer) + + @mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = bytearray(b'data') + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, bytearray(b'data')) + 
self.loop.assert_writer(5, tr._write_ready) + self.assertEqual(bytearray(b'data'), tr._buffer) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = bytearray(b'data') + m_write.side_effect = err = OSError() + tr._write_ready() + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual(bytearray(), tr._buffer) + self.assertTrue(tr.is_closing()) + m_logexc.assert_not_called() + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = bytearray(b'data') + m_write.return_value = 4 + tr._write_ready() + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual(bytearray(), tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @mock.patch('os.write') + def test_abort(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr.is_closing()) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = self.write_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + 
self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test__call_connection_lost_with_err(self): + tr = self.write_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test_close(self): + tr = self.write_pipe_transport() + tr.write_eof = mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + # closing the transport twice must not fail + tr.close() + + def test_close_closing(self): + tr = self.write_pipe_transport() + tr.write_eof = mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = self.write_pipe_transport() + tr.write_eof() + self.assertTrue(tr.is_closing()) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = self.write_pipe_transport() + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr.is_closing()) + self.assertFalse(self.protocol.connection_lost.called) + + +class AbstractChildWatcherTests(unittest.TestCase): + + def test_warns_on_subclassing(self): + with self.assertWarns(DeprecationWarning): + class MyWatcher(asyncio.AbstractChildWatcher): + pass + + def test_not_implemented(self): + f = mock.Mock() + watcher = asyncio.AbstractChildWatcher() + self.assertRaises( + NotImplementedError, watcher.add_child_handler, f, f) + self.assertRaises( + NotImplementedError, watcher.remove_child_handler, f) + self.assertRaises( + NotImplementedError, watcher.attach_loop, f) + self.assertRaises( + NotImplementedError, watcher.close) + self.assertRaises( + NotImplementedError, watcher.is_active) + self.assertRaises( + NotImplementedError, watcher.__enter__) + self.assertRaises( + 
NotImplementedError, watcher.__exit__, f, f, f) + + +class BaseChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + watcher = unix_events.BaseChildWatcher() + self.assertRaises( + NotImplementedError, watcher._do_waitpid, f) + + +class ChildWatcherTestsMixin: + + ignore_warnings = mock.patch.object(log.logger, "warning") + + def setUp(self): + super().setUp() + self.loop = self.new_test_loop() + self.running = False + self.zombies = {} + + with mock.patch.object( + self.loop, "add_signal_handler") as self.m_add_signal_handler: + self.watcher = self.create_watcher() + self.watcher.attach_loop(self.loop) + + def waitpid(self, pid, flags): + if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1: + self.assertGreater(pid, 0) + try: + if pid < 0: + return self.zombies.popitem() + else: + return pid, self.zombies.pop(pid) + except KeyError: + pass + if self.running: + return 0, 0 + else: + raise ChildProcessError() + + def add_zombie(self, pid, status): + self.zombies[pid] = status + + def waitstatus_to_exitcode(self, status): + if status > 32768: + return status - 32768 + elif 32700 < status < 32768: + return status - 32768 + else: + return status + + def test_create_watcher(self): + self.m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + def waitpid_mocks(func): + def wrapped_func(self): + def patch(target, wrapper): + return mock.patch(target, wraps=wrapper, + new_callable=mock.Mock) + + with patch('asyncio.unix_events.waitstatus_to_exitcode', self.waitstatus_to_exitcode), \ + patch('os.waitpid', self.waitpid) as m_waitpid: + func(self, m_waitpid) + return wrapped_func + + @waitpid_mocks + def test_sigchld(self, m_waitpid): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(42, callback, 9, 10, 14) + + self.assertFalse(callback.called) + + # child is running + self.watcher._sig_chld() + + 
self.assertFalse(callback.called) + + # child terminates (returncode 12) + self.running = False + self.add_zombie(42, EXITCODE(12)) + self.watcher._sig_chld() + + callback.assert_called_once_with(42, 12, 9, 10, 14) + + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(42, EXITCODE(13)) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_two_children(self, m_waitpid): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(43, callback1, 7, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(44, callback2, 147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # child 1 terminates (signal 3) + self.add_zombie(43, SIGNAL(3)) + self.watcher._sig_chld() + + callback1.assert_called_once_with(43, -3, 7, 8) + self.assertFalse(callback2.called) + + callback1.reset_mock() + + # child 2 still running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # child 2 terminates (code 108) + self.add_zombie(44, EXITCODE(108)) + self.running = False + self.watcher._sig_chld() + + callback2.assert_called_once_with(44, 108, 147, 18) + self.assertFalse(callback1.called) + + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(43, EXITCODE(14)) + self.add_zombie(44, EXITCODE(15)) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + 
self.assertFalse(callback2.called) + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + @waitpid_mocks + def test_sigchld_two_children_terminating_together(self, m_waitpid): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(45, callback1, 17, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(46, callback2, 1147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # child 1 terminates (code 78) + # child 2 terminates (signal 5) + self.add_zombie(45, EXITCODE(78)) + self.add_zombie(46, SIGNAL(5)) + self.running = False + self.watcher._sig_chld() + + callback1.assert_called_once_with(45, 78, 17, 8) + callback2.assert_called_once_with(46, -5, 1147, 18) + + callback1.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(45, EXITCODE(14)) + self.add_zombie(46, EXITCODE(15)) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + @waitpid_mocks + def test_sigchld_race_condition(self, m_waitpid): + # register a child + callback = mock.Mock() + + with self.watcher: + # child terminates before being registered + self.add_zombie(50, EXITCODE(4)) + self.watcher._sig_chld() + + self.watcher.add_child_handler(50, callback, 1, 12) + + callback.assert_called_once_with(50, 4, 1, 12) + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(50, SIGNAL(1)) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + 
@waitpid_mocks + def test_sigchld_replace_handler(self, m_waitpid): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(51, callback1, 19) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # register the same child again + with self.watcher: + self.watcher.add_child_handler(51, callback2, 21) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + # child terminates (signal 8) + self.running = False + self.add_zombie(51, SIGNAL(8)) + self.watcher._sig_chld() + + callback2.assert_called_once_with(51, -8, 21) + self.assertFalse(callback1.called) + + callback2.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(51, EXITCODE(13)) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + + @waitpid_mocks + def test_sigchld_remove_handler(self, m_waitpid): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(52, callback, 1984) + + self.assertFalse(callback.called) + + # unregister the child + self.watcher.remove_child_handler(52) + + self.assertFalse(callback.called) + + # child terminates (code 99) + self.running = False + self.add_zombie(52, EXITCODE(99)) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_unknown_status(self, m_waitpid): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(53, callback, -19) + + self.assertFalse(callback.called) + + # terminate with unknown status + self.zombies[53] = 1178 + self.running = False + self.watcher._sig_chld() + + callback.assert_called_once_with(53, 1178, -19) + + callback.reset_mock() + + # ensure that the child is effectively reaped + 
self.add_zombie(53, EXITCODE(101)) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_remove_child_handler(self, m_waitpid): + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + # register children + with self.watcher: + self.running = True + self.watcher.add_child_handler(54, callback1, 1) + self.watcher.add_child_handler(55, callback2, 2) + self.watcher.add_child_handler(56, callback3, 3) + + # remove child handler 1 + self.assertTrue(self.watcher.remove_child_handler(54)) + + # remove child handler 2 multiple times + self.assertTrue(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + + # all children terminate + self.add_zombie(54, EXITCODE(0)) + self.add_zombie(55, EXITCODE(1)) + self.add_zombie(56, EXITCODE(2)) + self.running = False + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(56, 2, 3) + + @waitpid_mocks + def test_sigchld_unhandled_exception(self, m_waitpid): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(57, callback) + + # raise an exception + m_waitpid.side_effect = ValueError + + with mock.patch.object(log.logger, + 'error') as m_error: + + self.assertEqual(self.watcher._sig_chld(), None) + self.assertTrue(m_error.called) + + @waitpid_mocks + def test_sigchld_child_reaped_elsewhere(self, m_waitpid): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(58, callback) + + self.assertFalse(callback.called) + + # child terminates + self.running = False + self.add_zombie(58, EXITCODE(4)) + + # waitpid is called elsewhere + os.waitpid(58, os.WNOHANG) + + m_waitpid.reset_mock() + + # sigchld 
+ with self.ignore_warnings: + self.watcher._sig_chld() + + if isinstance(self.watcher, asyncio.FastChildWatcher): + # here the FastChildWatcher enters a deadlock + # (there is no way to prevent it) + self.assertFalse(callback.called) + else: + callback.assert_called_once_with(58, 255) + + @waitpid_mocks + def test_sigchld_unknown_pid_during_registration(self, m_waitpid): + # register two children + callback1 = mock.Mock() + callback2 = mock.Mock() + + with self.ignore_warnings, self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(591, EXITCODE(7)) + # an unknown child terminates + self.add_zombie(593, EXITCODE(17)) + + self.watcher._sig_chld() + + self.watcher.add_child_handler(591, callback1) + self.watcher.add_child_handler(592, callback2) + + callback1.assert_called_once_with(591, 7) + self.assertFalse(callback2.called) + + @waitpid_mocks + def test_set_loop(self, m_waitpid): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(60, callback) + + # attach a new loop + old_loop = self.loop + self.loop = self.new_test_loop() + patch = mock.patch.object + + with patch(old_loop, "remove_signal_handler") as m_old_remove, \ + patch(self.loop, "add_signal_handler") as m_new_add: + + self.watcher.attach_loop(self.loop) + + m_old_remove.assert_called_once_with( + signal.SIGCHLD) + m_new_add.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + # child terminates + self.running = False + self.add_zombie(60, EXITCODE(9)) + self.watcher._sig_chld() + + callback.assert_called_once_with(60, 9) + + @waitpid_mocks + def test_set_loop_race_condition(self, m_waitpid): + # register 3 children + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(61, callback1) + self.watcher.add_child_handler(62, callback2) + self.watcher.add_child_handler(622, callback3) + + # detach 
the loop + old_loop = self.loop + self.loop = None + + with mock.patch.object( + old_loop, "remove_signal_handler") as m_remove_signal_handler: + + with self.assertWarnsRegex( + RuntimeWarning, 'A loop is being detached'): + self.watcher.attach_loop(None) + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + + # child 1 & 2 terminate + self.add_zombie(61, EXITCODE(11)) + self.add_zombie(62, SIGNAL(5)) + + # SIGCHLD was not caught + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(callback3.called) + + # attach a new loop + self.loop = self.new_test_loop() + + with mock.patch.object( + self.loop, "add_signal_handler") as m_add_signal_handler: + + self.watcher.attach_loop(self.loop) + + m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + callback1.assert_called_once_with(61, 11) # race condition! + callback2.assert_called_once_with(62, -5) # race condition! + self.assertFalse(callback3.called) + + callback1.reset_mock() + callback2.reset_mock() + + # child 3 terminates + self.running = False + self.add_zombie(622, EXITCODE(19)) + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(622, 19) + + @waitpid_mocks + def test_close(self, m_waitpid): + # register two children + callback1 = mock.Mock() + + with self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(63, EXITCODE(9)) + # other child terminates + self.add_zombie(65, EXITCODE(18)) + self.watcher._sig_chld() + + self.watcher.add_child_handler(63, callback1) + self.watcher.add_child_handler(64, callback1) + + self.assertEqual(len(self.watcher._callbacks), 1) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertEqual(len(self.watcher._zombies), 1) + + with mock.patch.object( + self.loop, + "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.close() + + 
m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + self.assertFalse(self.watcher._callbacks) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertFalse(self.watcher._zombies) + + +class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase): + def create_watcher(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return asyncio.SafeChildWatcher() + + +class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase): + def create_watcher(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return asyncio.FastChildWatcher() + + +class PolicyTests(unittest.TestCase): + + def create_policy(self): + return asyncio.DefaultEventLoopPolicy() + + @mock.patch('asyncio.unix_events.can_use_pidfd') + def test_get_default_child_watcher(self, m_can_use_pidfd): + m_can_use_pidfd.return_value = False + policy = self.create_policy() + self.assertIsNone(policy._watcher) + with self.assertWarns(DeprecationWarning): + watcher = policy.get_child_watcher() + self.assertIsInstance(watcher, asyncio.ThreadedChildWatcher) + + self.assertIs(policy._watcher, watcher) + with self.assertWarns(DeprecationWarning): + self.assertIs(watcher, policy.get_child_watcher()) + + m_can_use_pidfd.return_value = True + policy = self.create_policy() + self.assertIsNone(policy._watcher) + with self.assertWarns(DeprecationWarning): + watcher = policy.get_child_watcher() + self.assertIsInstance(watcher, asyncio.PidfdChildWatcher) + + self.assertIs(policy._watcher, watcher) + with self.assertWarns(DeprecationWarning): + self.assertIs(watcher, policy.get_child_watcher()) + + def test_get_child_watcher_after_set(self): + policy = self.create_policy() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + watcher = asyncio.FastChildWatcher() + policy.set_child_watcher(watcher) + + self.assertIs(policy._watcher, watcher) + with 
self.assertWarns(DeprecationWarning): + self.assertIs(watcher, policy.get_child_watcher()) + + def test_get_child_watcher_thread(self): + + def f(): + policy.set_event_loop(policy.new_event_loop()) + + self.assertIsInstance(policy.get_event_loop(), + asyncio.AbstractEventLoop) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + self.assertIsNone(watcher._loop) + + policy.get_event_loop().close() + + policy = self.create_policy() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + policy.set_child_watcher(asyncio.SafeChildWatcher()) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_child_watcher_replace_mainloop_existing(self): + policy = self.create_policy() + loop = policy.new_event_loop() + policy.set_event_loop(loop) + + # Explicitly setup SafeChildWatcher, + # default ThreadedChildWatcher has no _loop property + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + watcher = asyncio.SafeChildWatcher() + policy.set_child_watcher(watcher) + watcher.attach_loop(loop) + + self.assertIs(watcher._loop, loop) + + new_loop = policy.new_event_loop() + policy.set_event_loop(new_loop) + + self.assertIs(watcher._loop, new_loop) + + policy.set_event_loop(None) + + self.assertIs(watcher._loop, None) + + loop.close() + new_loop.close() + + +class TestFunctional(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + asyncio.set_event_loop(None) + + def test_add_reader_invalid_argument(self): + def assert_raises(): + return self.assertRaisesRegex(ValueError, r'Invalid file object') + + cb = lambda: None + + with assert_raises(): + self.loop.add_reader(object(), cb) + with assert_raises(): + self.loop.add_writer(object(), cb) + + with 
assert_raises(): + self.loop.remove_reader(object()) + with assert_raises(): + self.loop.remove_writer(object()) + + def test_add_reader_or_writer_transport_fd(self): + def assert_raises(): + return self.assertRaisesRegex( + RuntimeError, + r'File descriptor .* is used by transport') + + async def runner(): + tr, pr = await self.loop.create_connection( + lambda: asyncio.Protocol(), sock=rsock) + + try: + cb = lambda: None + + with assert_raises(): + self.loop.add_reader(rsock, cb) + with assert_raises(): + self.loop.add_reader(rsock.fileno(), cb) + + with assert_raises(): + self.loop.remove_reader(rsock) + with assert_raises(): + self.loop.remove_reader(rsock.fileno()) + + with assert_raises(): + self.loop.add_writer(rsock, cb) + with assert_raises(): + self.loop.add_writer(rsock.fileno(), cb) + + with assert_raises(): + self.loop.remove_writer(rsock) + with assert_raises(): + self.loop.remove_writer(rsock.fileno()) + + finally: + tr.close() + + rsock, wsock = socket.socketpair() + try: + self.loop.run_until_complete(runner()) + finally: + rsock.close() + wsock.close() + + +@support.requires_fork() +class TestFork(unittest.IsolatedAsyncioTestCase): + + async def test_fork_not_share_event_loop(self): + # The forked process should not share the event loop with the parent + loop = asyncio.get_running_loop() + r, w = os.pipe() + self.addCleanup(os.close, r) + self.addCleanup(os.close, w) + pid = os.fork() + if pid == 0: + # child + try: + with self.assertWarns(DeprecationWarning): + loop = asyncio.get_event_loop_policy().get_event_loop() + os.write(w, b'LOOP:' + str(id(loop)).encode()) + except RuntimeError: + os.write(w, b'NO LOOP') + except BaseException as e: + os.write(w, b'ERROR:' + ascii(e).encode()) + finally: + os._exit(0) + else: + # parent + result = os.read(r, 100) + self.assertEqual(result[:5], b'LOOP:', result) + self.assertNotEqual(int(result[5:]), id(loop)) + wait_process(pid, exitcode=0) + + @hashlib_helper.requires_hashdigest('md5') + 
@support.skip_if_sanitizer("TSAN doesn't support threads after fork", thread=True) + @unittest.skip('TODO: RUSTPYTHON') + # AttributeError: 'RLock' object has no attribute '_recursion_count' + def test_fork_signal_handling(self): + self.addCleanup(multiprocessing_cleanup_tests) + + # Sending signal to the forked process should not affect the parent + # process + ctx = multiprocessing.get_context('fork') + manager = ctx.Manager() + self.addCleanup(manager.shutdown) + child_started = manager.Event() + child_handled = manager.Event() + parent_handled = manager.Event() + + def child_main(): + def on_sigterm(*args): + child_handled.set() + sys.exit() + + signal.signal(signal.SIGTERM, on_sigterm) + child_started.set() + while True: + time.sleep(1) + + async def main(): + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, lambda *args: parent_handled.set()) + + process = ctx.Process(target=child_main) + process.start() + child_started.wait() + os.kill(process.pid, signal.SIGTERM) + process.join(timeout=support.SHORT_TIMEOUT) + + async def func(): + await asyncio.sleep(0.1) + return 42 + + # Test parent's loop is still functional + self.assertEqual(await asyncio.create_task(func()), 42) + + asyncio.run(main()) + + child_handled.wait(timeout=support.SHORT_TIMEOUT) + self.assertFalse(parent_handled.is_set()) + self.assertTrue(child_handled.is_set()) + + @hashlib_helper.requires_hashdigest('md5') + @support.skip_if_sanitizer("TSAN doesn't support threads after fork", thread=True) + @unittest.skip('TODO: RUSTPYTHON') + # AttributeError: 'RLock' object has no attribute '_recursion_count' + # AttributeError: module '_hashlib' has no attribute 'UnsupportedDigestmodError' + def test_fork_asyncio_run(self): + self.addCleanup(multiprocessing_cleanup_tests) + + ctx = multiprocessing.get_context('fork') + manager = ctx.Manager() + self.addCleanup(manager.shutdown) + result = manager.Value('i', 0) + + async def child_main(): + await asyncio.sleep(0.1) + 
result.value = 42 + + process = ctx.Process(target=lambda: asyncio.run(child_main())) + process.start() + process.join() + + self.assertEqual(result.value, 42) + + @hashlib_helper.requires_hashdigest('md5') + @support.skip_if_sanitizer("TSAN doesn't support threads after fork", thread=True) + @unittest.skip('TODO: RUSTPYTHON') + # AttributeError: 'RLock' object has no attribute '_recursion_count' + def test_fork_asyncio_subprocess(self): + self.addCleanup(multiprocessing_cleanup_tests) + + ctx = multiprocessing.get_context('fork') + manager = ctx.Manager() + self.addCleanup(manager.shutdown) + result = manager.Value('i', 1) + + async def child_main(): + proc = await asyncio.create_subprocess_exec(sys.executable, '-c', 'pass') + result.value = await proc.wait() + + process = ctx.Process(target=lambda: asyncio.run(child_main())) + process.start() + process.join() + + self.assertEqual(result.value, 0) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_waitfor.py b/Lib/test/test_asyncio/test_waitfor.py new file mode 100644 index 00000000000..11a8eeeab37 --- /dev/null +++ b/Lib/test/test_asyncio/test_waitfor.py @@ -0,0 +1,353 @@ +import asyncio +import unittest +import time +from test import support + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +# The following value can be used as a very small timeout: +# it passes check "timeout > 0", but has almost +# no effect on the test performance +_EPSILON = 0.0001 + + +class SlowTask: + """ Task will run for this defined time, ignoring cancel requests """ + TASK_TIMEOUT = 0.2 + + def __init__(self): + self.exited = False + + async def run(self): + exitat = time.monotonic() + self.TASK_TIMEOUT + + while True: + tosleep = exitat - time.monotonic() + if tosleep <= 0: + break + + try: + await asyncio.sleep(tosleep) + except asyncio.CancelledError: + pass + + self.exited = True + + +class AsyncioWaitForTest(unittest.IsolatedAsyncioTestCase): + + async def 
test_asyncio_wait_for_cancelled(self): + t = SlowTask() + + waitfortask = asyncio.create_task( + asyncio.wait_for(t.run(), t.TASK_TIMEOUT * 2)) + await asyncio.sleep(0) + waitfortask.cancel() + await asyncio.wait({waitfortask}) + + self.assertTrue(t.exited) + + async def test_asyncio_wait_for_timeout(self): + t = SlowTask() + + try: + await asyncio.wait_for(t.run(), t.TASK_TIMEOUT / 2) + except asyncio.TimeoutError: + pass + + self.assertTrue(t.exited) + + async def test_wait_for_timeout_less_then_0_or_0_future_done(self): + loop = asyncio.get_running_loop() + + fut = loop.create_future() + fut.set_result('done') + + ret = await asyncio.wait_for(fut, 0) + + self.assertEqual(ret, 'done') + self.assertTrue(fut.done()) + + async def test_wait_for_timeout_less_then_0_or_0_coroutine_do_not_started(self): + foo_started = False + + async def foo(): + nonlocal foo_started + foo_started = True + + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(foo(), 0) + + self.assertEqual(foo_started, False) + + async def test_wait_for_timeout_less_then_0_or_0(self): + loop = asyncio.get_running_loop() + + for timeout in [0, -1]: + with self.subTest(timeout=timeout): + foo_running = None + started = loop.create_future() + + async def foo(): + nonlocal foo_running + foo_running = True + started.set_result(None) + try: + await asyncio.sleep(10) + finally: + foo_running = False + return 'done' + + fut = asyncio.create_task(foo()) + await started + + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(fut, timeout) + + self.assertTrue(fut.done()) + # it should have been cancelled due to the timeout + self.assertTrue(fut.cancelled()) + self.assertEqual(foo_running, False) + + async def test_wait_for(self): + foo_running = None + + async def foo(): + nonlocal foo_running + foo_running = True + try: + await asyncio.sleep(support.LONG_TIMEOUT) + finally: + foo_running = False + return 'done' + + fut = asyncio.create_task(foo()) + + with 
self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(fut, 0.1) + self.assertTrue(fut.done()) + # it should have been cancelled due to the timeout + self.assertTrue(fut.cancelled()) + self.assertEqual(foo_running, False) + + async def test_wait_for_blocking(self): + async def coro(): + return 'done' + + res = await asyncio.wait_for(coro(), timeout=None) + self.assertEqual(res, 'done') + + async def test_wait_for_race_condition(self): + loop = asyncio.get_running_loop() + + fut = loop.create_future() + task = asyncio.wait_for(fut, timeout=0.2) + loop.call_soon(fut.set_result, "ok") + res = await task + self.assertEqual(res, "ok") + + async def test_wait_for_cancellation_race_condition(self): + async def inner(): + with self.assertRaises(asyncio.CancelledError): + await asyncio.sleep(1) + return 1 + + result = await asyncio.wait_for(inner(), timeout=.01) + self.assertEqual(result, 1) + + async def test_wait_for_waits_for_task_cancellation(self): + task_done = False + + async def inner(): + nonlocal task_done + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(_EPSILON) + raise + finally: + task_done = True + + inner_task = asyncio.create_task(inner()) + + with self.assertRaises(asyncio.TimeoutError) as cm: + await asyncio.wait_for(inner_task, timeout=_EPSILON) + + self.assertTrue(task_done) + chained = cm.exception.__context__ + self.assertEqual(type(chained), asyncio.CancelledError) + + async def test_wait_for_waits_for_task_cancellation_w_timeout_0(self): + task_done = False + + async def foo(): + async def inner(): + nonlocal task_done + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(_EPSILON) + raise + finally: + task_done = True + + inner_task = asyncio.create_task(inner()) + await asyncio.sleep(_EPSILON) + await asyncio.wait_for(inner_task, timeout=0) + + with self.assertRaises(asyncio.TimeoutError) as cm: + await foo() + + self.assertTrue(task_done) + chained = 
cm.exception.__context__ + self.assertEqual(type(chained), asyncio.CancelledError) + + async def test_wait_for_reraises_exception_during_cancellation(self): + class FooException(Exception): + pass + + async def foo(): + async def inner(): + try: + await asyncio.sleep(0.2) + finally: + raise FooException + + inner_task = asyncio.create_task(inner()) + + await asyncio.wait_for(inner_task, timeout=_EPSILON) + + with self.assertRaises(FooException): + await foo() + + async def _test_cancel_wait_for(self, timeout): + loop = asyncio.get_running_loop() + + async def blocking_coroutine(): + fut = loop.create_future() + # Block: fut result is never set + await fut + + task = asyncio.create_task(blocking_coroutine()) + + wait = asyncio.create_task(asyncio.wait_for(task, timeout)) + loop.call_soon(wait.cancel) + + with self.assertRaises(asyncio.CancelledError): + await wait + + # Python issue #23219: cancelling the wait must also cancel the task + self.assertTrue(task.cancelled()) + + async def test_cancel_blocking_wait_for(self): + await self._test_cancel_wait_for(None) + + async def test_cancel_wait_for(self): + await self._test_cancel_wait_for(60.0) + + async def test_wait_for_cancel_suppressed(self): + # GH-86296: Suppressing CancelledError is discouraged + # but if a task suppresses CancelledError and returns a value, + # `wait_for` should return the value instead of raising CancelledError. + # This is the same behavior as `asyncio.timeout`. + + async def return_42(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + return 42 + + res = await asyncio.wait_for(return_42(), timeout=0.1) + self.assertEqual(res, 42) + + + async def test_wait_for_issue86296(self): + # GH-86296: The task should get cancelled and not run to completion. + # inner completes in one cycle of the event loop so it + # completes before the task is cancelled. 
+ + async def inner(): + return 'done' + + inner_task = asyncio.create_task(inner()) + reached_end = False + + async def wait_for_coro(): + await asyncio.wait_for(inner_task, timeout=100) + await asyncio.sleep(1) + nonlocal reached_end + reached_end = True + + task = asyncio.create_task(wait_for_coro()) + self.assertFalse(task.done()) + # Run the task + await asyncio.sleep(0) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(inner_task.done()) + self.assertEqual(await inner_task, 'done') + self.assertFalse(reached_end) + + +class WaitForShieldTests(unittest.IsolatedAsyncioTestCase): + + async def test_zero_timeout(self): + # `asyncio.shield` creates a new task which wraps the passed in + # awaitable and shields it from cancellation so with timeout=0 + # the task returned by `asyncio.shield` aka shielded_task gets + # cancelled immediately and the task wrapped by it is scheduled + # to run. + + async def coro(): + await asyncio.sleep(0.01) + return 'done' + + task = asyncio.create_task(coro()) + with self.assertRaises(asyncio.TimeoutError): + shielded_task = asyncio.shield(task) + await asyncio.wait_for(shielded_task, timeout=0) + + # Task is running in background + self.assertFalse(task.done()) + self.assertFalse(task.cancelled()) + self.assertTrue(shielded_task.cancelled()) + + # Wait for the task to complete + await asyncio.sleep(0.1) + self.assertTrue(task.done()) + + + async def test_none_timeout(self): + # With timeout=None the timeout is disabled so it + # runs till completion. + async def coro(): + await asyncio.sleep(0.1) + return 'done' + + task = asyncio.create_task(coro()) + await asyncio.wait_for(asyncio.shield(task), timeout=None) + + self.assertTrue(task.done()) + self.assertEqual(await task, "done") + + async def test_shielded_timeout(self): + # shield prevents the task from being cancelled. 
+ async def coro(): + await asyncio.sleep(0.1) + return 'done' + + task = asyncio.create_task(coro()) + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(task), timeout=0.01) + + self.assertFalse(task.done()) + self.assertFalse(task.cancelled()) + self.assertEqual(await task, "done") + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py new file mode 100644 index 00000000000..0c128c599ba --- /dev/null +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -0,0 +1,359 @@ +import os +import signal +import socket +import sys +import time +import threading +import unittest +from unittest import mock + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _overlapped +import _winapi + +import asyncio +from asyncio import windows_events +from test.test_asyncio import utils as test_utils + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class UpperProto(asyncio.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class WindowsEventsTestCase(test_utils.TestCase): + def _unraisablehook(self, unraisable): + # Storing unraisable.object can resurrect an object which is being + # finalized. Storing unraisable.exc_value creates a reference cycle. 
+ self._unraisable = unraisable + print(unraisable) + + def setUp(self): + self._prev_unraisablehook = sys.unraisablehook + self._unraisable = None + sys.unraisablehook = self._unraisablehook + + def tearDown(self): + sys.unraisablehook = self._prev_unraisablehook + self.assertIsNone(self._unraisable) + +class ProactorLoopCtrlC(WindowsEventsTestCase): + + def test_ctrl_c(self): + + def SIGINT_after_delay(): + time.sleep(0.1) + signal.raise_signal(signal.SIGINT) + + thread = threading.Thread(target=SIGINT_after_delay) + loop = asyncio.new_event_loop() + try: + # only start the loop once the event loop is running + loop.call_soon(thread.start) + loop.run_forever() + self.fail("should not fall through 'run_forever'") + except KeyboardInterrupt: + pass + finally: + self.close_loop(loop) + thread.join() + + +class ProactorMultithreading(WindowsEventsTestCase): + def test_run_from_nonmain_thread(self): + finished = False + + async def coro(): + await asyncio.sleep(0) + + def func(): + nonlocal finished + loop = asyncio.new_event_loop() + loop.run_until_complete(coro()) + # close() must not call signal.set_wakeup_fd() + loop.close() + finished = True + + thread = threading.Thread(target=func) + thread.start() + thread.join() + self.assertTrue(finished) + + +class ProactorTests(WindowsEventsTestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.ProactorEventLoop() + self.set_event_loop(self.loop) + + def test_close(self): + a, b = socket.socketpair() + trans = self.loop._make_socket_transport(a, asyncio.Protocol()) + f = asyncio.ensure_future(self.loop.sock_recv(b, 100), loop=self.loop) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + b.close() + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = 
self.loop.run_until_complete(self._test_pipe())
        self.assertEqual(res, 'done')

    async def _test_pipe(self):
        # End-to-end named-pipe test: serve with UpperProto, connect five
        # stream clients, and check each 'lower-N' line comes back upper-cased.
        ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid()

        with self.assertRaises(FileNotFoundError):
            await self.loop.create_pipe_connection(
                asyncio.Protocol, ADDRESS)

        [server] = await self.loop.start_serving_pipe(
            UpperProto, ADDRESS)
        self.assertIsInstance(server, windows_events.PipeServer)

        clients = []
        for i in range(5):
            stream_reader = asyncio.StreamReader(loop=self.loop)
            protocol = asyncio.StreamReaderProtocol(stream_reader,
                                                    loop=self.loop)
            trans, proto = await self.loop.create_pipe_connection(
                lambda: protocol, ADDRESS)
            self.assertIsInstance(trans, asyncio.Transport)
            self.assertEqual(protocol, proto)
            clients.append((stream_reader, trans))

        for i, (r, w) in enumerate(clients):
            w.write('lower-{}\n'.format(i).encode())

        for i, (r, w) in enumerate(clients):
            response = await r.readline()
            self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
            w.close()

        server.close()

        # After server.close() the pipe name must be gone again.
        with self.assertRaises(FileNotFoundError):
            await self.loop.create_pipe_connection(
                asyncio.Protocol, ADDRESS)

        return 'done'

    def test_connect_pipe_cancel(self):
        # Simulate a perpetually-busy pipe so connect_pipe() keeps retrying.
        exc = OSError()
        exc.winerror = _overlapped.ERROR_PIPE_BUSY
        with mock.patch.object(_overlapped, 'ConnectPipe',
                               side_effect=exc) as connect:
            coro = self.loop._proactor.connect_pipe('pipe_address')
            task = self.loop.create_task(coro)

            # check that it's possible to cancel connect_pipe()
            task.cancel()
            with self.assertRaises(asyncio.CancelledError):
                self.loop.run_until_complete(task)

    def test_wait_for_handle(self):
        event = _overlapped.CreateEvent(None, True, False, None)
        self.addCleanup(_winapi.CloseHandle, event)

        # Wait for unset event with 0.5s timeout;
        # result should be False at timeout
        timeout = 0.5
        fut = self.loop._proactor.wait_for_handle(event, timeout)
        start = self.loop.time()
        done = self.loop.run_until_complete(fut)
        elapsed = self.loop.time() - start

        self.assertEqual(done, False)
        self.assertFalse(fut.result())
        # CLOCK_RES tolerates the coarse Windows clock resolution.
        self.assertGreaterEqual(elapsed, timeout - test_utils.CLOCK_RES)

        _overlapped.SetEvent(event)

        # Wait for set event;
        # result should be True immediately
        fut = self.loop._proactor.wait_for_handle(event, 10)
        done = self.loop.run_until_complete(fut)

        self.assertEqual(done, True)
        self.assertTrue(fut.result())

        # asyncio issue #195: cancelling a done _WaitHandleFuture
        # must not crash
        fut.cancel()

    def test_wait_for_handle_cancel(self):
        event = _overlapped.CreateEvent(None, True, False, None)
        self.addCleanup(_winapi.CloseHandle, event)

        # Wait for unset event with a cancelled future;
        # CancelledError should be raised immediately
        fut = self.loop._proactor.wait_for_handle(event, 10)
        fut.cancel()
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(fut)

        # asyncio issue #195: cancelling a _WaitHandleFuture twice
        # must not crash
        fut = self.loop._proactor.wait_for_handle(event)
        fut.cancel()
        fut.cancel()

    def test_read_self_pipe_restart(self):
        # Regression test for https://bugs.python.org/issue39010
        # Previously, restarting a proactor event loop in certain states
        # would lead to spurious ConnectionResetErrors being logged.
        self.loop.call_exception_handler = mock.Mock()
        # Start an operation in another thread so that the self-pipe is used.
        # This is theoretically timing-dependent (the task in the executor
        # must complete before our start/stop cycles), but in practice it
        # seems to work every time.
        f = self.loop.run_in_executor(None, lambda: None)
        self.loop.stop()
        self.loop.run_forever()
        self.loop.stop()
        self.loop.run_forever()

        # Shut everything down cleanly. This is an important part of the
        # test - in issue 39010, the error occurred during loop.close(),
        # so we want to close the loop during the test instead of leaving
        # it for tearDown.
        #
        # First wait for f to complete to avoid a "future's result was never
        # retrieved" error.
        self.loop.run_until_complete(f)
        # Now shut down the loop itself (self.close_loop also shuts down the
        # loop's default executor).
        self.close_loop(self.loop)
        self.assertFalse(self.loop.call_exception_handler.called)

    def test_address_argument_type_error(self):
        # Regression test for https://github.com/python/cpython/issues/98793
        proactor = self.loop._proactor
        sock = socket.socket(type=socket.SOCK_DGRAM)
        bad_address = None
        with self.assertRaises(TypeError):
            proactor.connect(sock, bad_address)
        with self.assertRaises(TypeError):
            proactor.sendto(sock, b'abc', addr=bad_address)
        sock.close()

    def test_client_pipe_stat(self):
        res = self.loop.run_until_complete(self._test_client_pipe_stat())
        self.assertEqual(res, 'done')

    async def _test_client_pipe_stat(self):
        # Regression test for https://github.com/python/cpython/issues/100573
        ADDRESS = r'\\.\pipe\test_client_pipe_stat-%s' % os.getpid()

        async def probe():
            # See https://github.com/python/cpython/pull/100959#discussion_r1068533658
            h = _overlapped.ConnectPipe(ADDRESS)
            try:
                _winapi.CloseHandle(_overlapped.ConnectPipe(ADDRESS))
            except OSError as e:
                if e.winerror != _overlapped.ERROR_PIPE_BUSY:
                    raise
            finally:
                _winapi.CloseHandle(h)

        with self.assertRaises(FileNotFoundError):
            await probe()

        [server] = await self.loop.start_serving_pipe(asyncio.Protocol, ADDRESS)
        self.assertIsInstance(server, windows_events.PipeServer)

        # Collect anything routed to the exception handler; the probe
        # connections must not trigger it.
        errors = []
        self.loop.set_exception_handler(lambda _, data: errors.append(data))

        for i in range(5):
            await self.loop.create_task(probe())

        self.assertEqual(len(errors), 0, errors)

        server.close()

        with self.assertRaises(FileNotFoundError):
            await probe()

        return "done"

    def test_loop_restart(self):
        # We're fishing for the "RuntimeError: <_overlapped.Overlapped object at XXX>
        # still has pending operation at deallocation, the process may crash" error
        stop = threading.Event()
        def threadMain():
            # Hammer the self-pipe from another thread while the main
            # thread restarts the loop.
            while not stop.is_set():
                self.loop.call_soon_threadsafe(lambda: None)
                time.sleep(0.01)
        thr = threading.Thread(target=threadMain)

        # In 10 60-second runs of this test prior to the fix:
        # time in seconds until failure: (none), 15.0, 6.4, (none), 7.6, 8.3, 1.7, 22.2, 23.5, 8.3
        # 10 seconds had a 50% failure rate but longer would be more costly
        end_time = time.time() + 10  # Run for 10 seconds
        self.loop.call_soon(thr.start)
        while not self._unraisable:  # Stop if we got an unraisable exc
            self.loop.stop()
            self.loop.run_forever()
            if time.time() >= end_time:
                break

        stop.set()
        thr.join()


class WinPolicyTests(WindowsEventsTestCase):
    # Sanity-check that the two Windows event-loop policies hand out the
    # expected loop types under asyncio.run().

    def test_selector_win_policy(self):
        async def main():
            self.assertIsInstance(
                asyncio.get_running_loop(),
                asyncio.SelectorEventLoop)

        old_policy = asyncio.get_event_loop_policy()
        try:
            asyncio.set_event_loop_policy(
                asyncio.WindowsSelectorEventLoopPolicy())
            asyncio.run(main())
        finally:
            asyncio.set_event_loop_policy(old_policy)

    def test_proactor_win_policy(self):
        async def main():
            self.assertIsInstance(
                asyncio.get_running_loop(),
                asyncio.ProactorEventLoop)

        old_policy = asyncio.get_event_loop_policy()
        try:
            asyncio.set_event_loop_policy(
                asyncio.WindowsProactorEventLoopPolicy())
            asyncio.run(main())
        finally:
            asyncio.set_event_loop_policy(old_policy)


if __name__ == '__main__':
    unittest.main()
diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py
new file mode 100644
index 00000000000..eafa5be3829
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_utils.py
@@ -0,0 +1,133 @@
"""Tests for windows_utils"""

import sys
import unittest
import warnings

if sys.platform != 'win32':
    raise unittest.SkipTest('Windows only')

import _overlapped
import _winapi

import asyncio
from asyncio import windows_utils
from test import support


def tearDownModule():
    # Reset any policy installed by the tests.
    asyncio.set_event_loop_policy(None)


class PipeTests(unittest.TestCase):

    def test_pipe_overlapped(self):
        # Exercise overlapped read/write on a pipe pair: a pending read
        # completes once the peer writes, and getresult() returns the data.
        h1, h2 = windows_utils.pipe(overlapped=(True, True))
        try:
            ov1 = _overlapped.Overlapped()
            self.assertFalse(ov1.pending)
            self.assertEqual(ov1.error, 0)

            ov1.ReadFile(h1, 100)
            self.assertTrue(ov1.pending)
            self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING)
            ERROR_IO_INCOMPLETE = 996
            try:
                ov1.getresult()
            except OSError as e:
                self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE)
            else:
                raise RuntimeError('expected ERROR_IO_INCOMPLETE')

            ov2 = _overlapped.Overlapped()
            self.assertFalse(ov2.pending)
            self.assertEqual(ov2.error, 0)

            ov2.WriteFile(h2, b"hello")
            self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})

            res = _winapi.WaitForMultipleObjects([ov2.event], False, 100)
            self.assertEqual(res, _winapi.WAIT_OBJECT_0)

            self.assertFalse(ov1.pending)
            self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE)
            self.assertFalse(ov2.pending)
            self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
            self.assertEqual(ov1.getresult(), b"hello")
        finally:
            _winapi.CloseHandle(h1)
            _winapi.CloseHandle(h2)

    def test_pipe_handle(self):
        h, _ = windows_utils.pipe(overlapped=(True, True))
        _winapi.CloseHandle(_)
        p = windows_utils.PipeHandle(h)
        self.assertEqual(p.fileno(), h)
        self.assertEqual(p.handle, h)

        # check garbage collection of p closes handle
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "", ResourceWarning)
            del p
            support.gc_collect()
        try:
            _winapi.CloseHandle(h)
        except OSError as e:
            self.assertEqual(e.winerror, 6)  # ERROR_INVALID_HANDLE
        else:
            raise RuntimeError('expected ERROR_INVALID_HANDLE')


class PopenTests(unittest.TestCase):

    def test_popen(self):
        # Round-trip a line through a child process' stdin/stdout/stderr
        # using overlapped I/O on the PipeHandle objects.
        command = r"""if 1:
            import sys
            s = sys.stdin.readline()
            sys.stdout.write(s.upper())
            sys.stderr.write('stderr')
            """
        msg = b"blah\n"

        p = windows_utils.Popen([sys.executable, '-c', command],
                                stdin=windows_utils.PIPE,
                                stdout=windows_utils.PIPE,
                                stderr=windows_utils.PIPE)

        for f in [p.stdin, p.stdout, p.stderr]:
            self.assertIsInstance(f, windows_utils.PipeHandle)

        ovin = _overlapped.Overlapped()
        ovout = _overlapped.Overlapped()
        overr = _overlapped.Overlapped()

        ovin.WriteFile(p.stdin.handle, msg)
        ovout.ReadFile(p.stdout.handle, 100)
        overr.ReadFile(p.stderr.handle, 100)

        events = [ovin.event, ovout.event, overr.event]
        # Super-long timeout for slow buildbots.
        res = _winapi.WaitForMultipleObjects(events, True,
                                             int(support.SHORT_TIMEOUT * 1000))
        self.assertEqual(res, _winapi.WAIT_OBJECT_0)
        self.assertFalse(ovout.pending)
        self.assertFalse(overr.pending)
        self.assertFalse(ovin.pending)

        self.assertEqual(ovin.getresult(), len(msg))
        out = ovout.getresult().rstrip()
        err = overr.getresult().rstrip()

        self.assertGreater(len(out), 0)
        self.assertGreater(len(err), 0)
        # allow for partial reads...
        self.assertTrue(msg.upper().rstrip().startswith(out))
        self.assertTrue(b"stderr".startswith(err))

        # The context manager calls wait() and closes resources
        with p:
            pass


if __name__ == '__main__':
    unittest.main()
diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py
new file mode 100644
index 00000000000..d9a2939be13
--- /dev/null
+++ b/Lib/test/test_asyncio/utils.py
@@ -0,0 +1,638 @@
"""Utilities shared by tests."""

import asyncio
import collections
import contextlib
import io
import logging
import os
import re
import selectors
import socket
import socketserver
import sys
import threading
import unittest
import weakref
import warnings
from unittest import mock

from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer

try:
    import ssl
except ImportError:  # pragma: no cover
    ssl = None

from asyncio import base_events
from asyncio import events
from asyncio import format_helpers
from asyncio import tasks
from asyncio.log import logger
from test import support
from test.support import socket_helper
from test.support import threading_helper


# Use the maximum known clock resolution (gh-75191, gh-110088): Windows
# GetTickCount64() has a resolution of 15.6 ms. Use 50 ms to tolerate rounding
# issues.
CLOCK_RES = 0.050


def data_file(*filename):
    """Return the absolute path of a test data file.

    The path components in *filename* are looked up first under
    ``support.TEST_HOME_DIR`` and then one directory above this package.

    Raises FileNotFoundError if the file exists in neither location.
    """
    fullname = os.path.join(support.TEST_HOME_DIR, *filename)
    if os.path.isfile(fullname):
        return fullname
    fullname = os.path.join(os.path.dirname(__file__), '..', *filename)
    if os.path.isfile(fullname):
        return fullname
    # Bug fix: the components must be unpacked here as well;
    # os.path.join(filename) passes the whole tuple as a single argument,
    # which raises TypeError and masks the intended FileNotFoundError.
    raise FileNotFoundError(os.path.join(*filename))


ONLYCERT = data_file('certdata', 'ssl_cert.pem')
ONLYKEY = data_file('certdata', 'ssl_key.pem')
SIGNED_CERTFILE = data_file('certdata', 'keycert3.pem')
SIGNING_CA = data_file('certdata', 'pycacert.pem')
# Expected peer certificate fields for SIGNED_CERTFILE.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
               (('organizationName', 'Python Software Foundation CA'),),
               (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
                (('localityName', 'Castle Anthrax'),),
                (('organizationName', 'Python Software Foundation'),),
                (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}


def simple_server_sslcontext():
    """Server-side TLS context using the test certificate, no verification."""
    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(ONLYCERT, ONLYKEY)
    server_context.check_hostname = False
    server_context.verify_mode = ssl.CERT_NONE
    return server_context


def simple_client_sslcontext(*, disable_verify=True):
    """Client-side TLS context; verification disabled unless requested."""
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.check_hostname = False
    if disable_verify:
        client_context.verify_mode = ssl.CERT_NONE
    return client_context


def dummy_ssl_context():
    """A non-verifying client context, or None if ssl is unavailable."""
    if ssl is None:
        return None
    else:
        return simple_client_sslcontext(disable_verify=True)


def run_briefly(loop):
    async def once():
        pass
    gen = once()
    t = loop.create_task(gen)
    # Don't
log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False + try: + loop.run_until_complete(t) + finally: + gen.close() + + +def run_until(loop, pred, timeout=support.SHORT_TIMEOUT): + delay = 0.001 + for _ in support.busy_retry(timeout, error=False): + if pred(): + break + loop.run_until_complete(tasks.sleep(delay)) + delay = max(delay * 2, 1.0) + else: + raise TimeoutError() + + +def run_once(loop): + """Legacy API to run once through the event loop. + + This is the recommended pattern for test code. It will poll the + selector once and run all callbacks scheduled in response to I/O + events. + """ + loop.call_soon(loop.stop) + loop.run_forever() + + +class SilentWSGIRequestHandler(WSGIRequestHandler): + + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + +class SilentWSGIServer(WSGIServer): + + request_timeout = support.LOOPBACK_TIMEOUT + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + return request, client_addr + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. 
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
            ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass


class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    pass


def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    # Generator used via 'yield from' by the run_test_server()/
    # run_test_unix_server() context managers below: starts a WSGI test
    # server in a background thread, yields it, then shuts it down.

    def loop(environ):
        # Stream the request body back to the client in <=64 KiB chunks.
        size = int(environ['CONTENT_LENGTH'])
        while size:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()


if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # Placeholder name/port: the HTTPServer machinery expects these
            # attributes, but they have no meaning for a UNIX socket.
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):

        request_timeout = support.LOOPBACK_TIMEOUT

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple
(host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):

        def handle_error(self, request, client_address):
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        pass


    def gen_unix_socket_path():
        # A fresh, unused UNIX-domain socket path.
        return socket_helper.create_unix_domain_name()


    @contextlib.contextmanager
    def unix_socket_path():
        # Yield a socket path and best-effort unlink it afterwards.
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        # UNIX-domain-socket variant of run_test_server().
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)


@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    # Run a (optionally TLS) WSGI test server on a loopback TCP socket.
    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)


def echo_datagrams(sock):
    # Echo every datagram back to its sender until a b'STOP' sentinel
    # arrives, then close the socket.
    while True:
        data, addr = sock.recvfrom(4096)
        if data == b'STOP':
            sock.close()
            break
        else:
            sock.sendto(data, addr)


@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
    # Run echo_datagrams() in a background thread; yields the bound address.
    addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    family, type, proto, _, sockaddr = addr_info[0]
    sock = socket.socket(family, type, proto)
    sock.bind((host, port))
    sockname = sock.getsockname()
    thread = threading.Thread(target=lambda: echo_datagrams(sock))
    thread.start()
    try:
        yield sockname
    finally:
        # gh-122187: use a separate socket to send the stop message to avoid
        # TSan reported race on the same socket.
        sock2 = socket.socket(family, type, proto)
        sock2.sendto(b'STOP', sockname)
        sock2.close()
        thread.join()


def make_test_protocol(base):
    # Build an instance of *base* whose every non-dunder attribute is a
    # Mock, so tests can assert on calls made by a transport.
    dct = {}
    for name in dir(base):
        if name.startswith('__') and name.endswith('__'):
            # skip magic names
            continue
        dct[name] = MockCallback(return_value=None)
    return type('TestProtocol', (base,) + base.__bases__, dct)()


class TestSelector(selectors.BaseSelector):
    # Minimal selector stub: records registrations, never reports events.

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        return self.keys.pop(fileobj)

    def select(self, timeout):
        return []

    def get_map(self):
        return self.keys


class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    It manages self time directly.
    If something scheduled to be executed later then
    on next loop iteration after all ready handlers done
    generator passed to __init__ is calling.

    Generator should be like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    Value returned by yield is absolute time of next scheduled handler.
    Value passed to yield is time advance to move loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # No time generator supplied: use a trivial one and skip the
            # exhaustion check on close().
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed every scheduled 'when' into the time generator and advance
        # the virtual clock by whatever it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass


def MockCallback(**kwargs):
    # A Mock restricted to being called (spec=['__call__']).
    return mock.Mock(spec=['__call__'], **kwargs)


class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        return bool(re.search(str(self), other, re.S))


class MockInstanceOf:
    # Equality helper: compares equal to any instance of the given type.
    def __init__(self, type):
        self._type = type

    def __eq__(self, other):
        return isinstance(other, self._type)


def get_function_source(func):
    # Source-location info for *func* via asyncio's format_helpers;
    # raises ValueError when it cannot be determined.
    source = format_helpers._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source


class TestCase(unittest.TestCase):
    @staticmethod
    def close_loop(loop):
        # Shut down the loop's default executor (via the loop if it is
        # still open), close the loop, then make sure no child-watcher
        # subprocess threads are left running.
        if loop._default_executor is not None:
            if not loop.is_closed():
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                loop._default_executor.shutdown(wait=True)
        loop.close()

        policy = support.maybe_get_event_loop_policy()
        if policy is not None:
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore', DeprecationWarning)
                    watcher = policy.get_child_watcher()
            except NotImplementedError:
                # watcher is not implemented by EventLoopPolicy, e.g. Windows
                pass
            else:
                if isinstance(watcher, asyncio.ThreadedChildWatcher):
                    # Wait for subprocess to finish, but not forever
                    for thread in list(watcher._threads.values()):
                        thread.join(timeout=support.SHORT_TIMEOUT)
                        if thread.is_alive():
                            raise RuntimeError(f"thread {thread} still alive: "
                                               "subprocess still running")


    def set_event_loop(self, loop, *, cleanup=True):
        if loop is None:
            raise AssertionError('loop is None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        # Create a TestLoop driven by *gen* and register it for cleanup.
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def setUp(self):
        self._thread_cleanup = threading_helper.threading_setup()

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertIsNone(sys.exception())

        self.doCleanups()
        threading_helper.threading_cleanup(*self._thread_cleanup)
        support.reap_children()


@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
+ """ + old_level = logger.level + try: + logger.setLevel(logging.CRITICAL+1) + yield + finally: + logger.setLevel(old_level) + + +def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM, + family=socket.AF_INET): + """Create a mock of a non-blocking socket.""" + sock = mock.MagicMock(socket.socket) + sock.proto = proto + sock.type = type + sock.family = family + sock.gettimeout.return_value = 0.0 + return sock + + +async def await_without_task(coro): + exc = None + def func(): + try: + for _ in coro.__await__(): + pass + except BaseException as err: + nonlocal exc + exc = err + asyncio.get_running_loop().call_soon(func) + await asyncio.sleep(0) + if exc is not None: + raise exc diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 7909e924251..c44b2b00684 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -16,6 +16,7 @@ use crate::{ symboltable::{self, CompilerScope, SymbolFlags, SymbolScope, SymbolTable}, unparse::UnparseExpr, }; +use alloc::borrow::Cow; use itertools::Itertools; use malachite_bigint::BigInt; use num_complex::Complex; @@ -42,7 +43,7 @@ use rustpython_compiler_core::{ }, }; use rustpython_wtf8::Wtf8Buf; -use std::{borrow::Cow, collections::HashSet}; +use std::collections::HashSet; const MAXBLOCKS: usize = 20; @@ -293,7 +294,7 @@ fn compiler_unwrap_option(zelf: &Compiler, o: Option) -> T { o.unwrap() } -// fn compiler_result_unwrap(zelf: &Compiler, result: Result) -> T { +// fn compiler_result_unwrap(zelf: &Compiler, result: Result) -> T { // if result.is_err() { // eprintln!("=== CODEGEN PANIC INFO ==="); // eprintln!("This IS an internal error, an result was unwrapped during codegen"); @@ -1831,7 +1832,7 @@ impl Compiler { name.to_owned(), ); - let args_iter = std::iter::empty() + let args_iter = core::iter::empty() .chain(¶meters.posonlyargs) .chain(¶meters.args) .map(|arg| &arg.parameter) @@ -2438,7 +2439,7 @@ impl Compiler { let mut funcflags = 
bytecode::MakeFunctionFlags::empty(); // Handle positional defaults - let defaults: Vec<_> = std::iter::empty() + let defaults: Vec<_> = core::iter::empty() .chain(¶meters.posonlyargs) .chain(¶meters.args) .filter_map(|x| x.default.as_deref()) @@ -2566,7 +2567,7 @@ impl Compiler { let mut num_annotations = 0; // Handle parameter annotations - let parameters_iter = std::iter::empty() + let parameters_iter = core::iter::empty() .chain(¶meters.posonlyargs) .chain(¶meters.args) .chain(¶meters.kwonlyargs) @@ -4965,7 +4966,7 @@ impl Compiler { let name = "".to_owned(); // Prepare defaults before entering function - let defaults: Vec<_> = std::iter::empty() + let defaults: Vec<_> = core::iter::empty() .chain(¶ms.posonlyargs) .chain(¶ms.args) .filter_map(|x| x.default.as_deref()) diff --git a/crates/codegen/src/error.rs b/crates/codegen/src/error.rs index 70e2f13f253..459ba8e33b5 100644 --- a/crates/codegen/src/error.rs +++ b/crates/codegen/src/error.rs @@ -1,5 +1,6 @@ +use alloc::fmt; +use core::fmt::Display; use rustpython_compiler_core::SourceLocation; -use std::fmt::{self, Display}; use thiserror::Error; #[derive(Debug)] @@ -93,7 +94,7 @@ pub enum CodegenErrorType { NotImplementedYet, // RustPython marker for unimplemented features } -impl std::error::Error for CodegenErrorType {} +impl core::error::Error for CodegenErrorType {} impl fmt::Display for CodegenErrorType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/crates/codegen/src/ir.rs b/crates/codegen/src/ir.rs index de0126f1122..670635fbd37 100644 --- a/crates/codegen/src/ir.rs +++ b/crates/codegen/src/ir.rs @@ -1,4 +1,4 @@ -use std::ops; +use core::ops; use crate::{IndexMap, IndexSet, error::InternalError}; use rustpython_compiler_core::{ @@ -198,7 +198,7 @@ impl CodeInfo { *arg = new_arg; } let (extras, lo_arg) = arg.split(); - locations.extend(std::iter::repeat_n(info.location, arg.instr_size())); + locations.extend(core::iter::repeat_n(info.location, arg.instr_size())); 
instructions.extend( extras .map(|byte| CodeUnit::new(Instruction::ExtendedArg, byte)) @@ -401,7 +401,7 @@ fn stackdepth_push( fn iter_blocks(blocks: &[Block]) -> impl Iterator + '_ { let mut next = BlockIdx(0); - std::iter::from_fn(move || { + core::iter::from_fn(move || { if next == BlockIdx::NULL { return None; } diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 291b57d7f67..34d3870ae91 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -5,6 +5,8 @@ #[macro_use] extern crate log; +extern crate alloc; + type IndexMap = indexmap::IndexMap; type IndexSet = indexmap::IndexSet; diff --git a/crates/codegen/src/string_parser.rs b/crates/codegen/src/string_parser.rs index ede2f118c37..175e75c1a26 100644 --- a/crates/codegen/src/string_parser.rs +++ b/crates/codegen/src/string_parser.rs @@ -5,7 +5,7 @@ //! after ruff has already successfully parsed the string literal, meaning //! we don't need to do any validation or error handling. -use std::convert::Infallible; +use core::convert::Infallible; use ruff_python_ast::{AnyStringFlags, StringFlags}; use rustpython_wtf8::{CodePoint, Wtf8, Wtf8Buf}; @@ -96,7 +96,7 @@ impl StringParser { } // OK because radix_bytes is always going to be in the ASCII range. 
- let radix_str = std::str::from_utf8(&radix_bytes[..len]).expect("ASCII bytes"); + let radix_str = core::str::from_utf8(&radix_bytes[..len]).expect("ASCII bytes"); let value = u32::from_str_radix(radix_str, 8).unwrap(); char::from_u32(value).unwrap() } diff --git a/crates/codegen/src/symboltable.rs b/crates/codegen/src/symboltable.rs index 3c8454b9e22..1629e5fff38 100644 --- a/crates/codegen/src/symboltable.rs +++ b/crates/codegen/src/symboltable.rs @@ -11,6 +11,7 @@ use crate::{ IndexMap, error::{CodegenError, CodegenErrorType}, }; +use alloc::{borrow::Cow, fmt}; use bitflags::bitflags; use ruff_python_ast::{ self as ast, Comprehension, Decorator, Expr, Identifier, ModExpression, ModModule, Parameter, @@ -20,7 +21,6 @@ use ruff_python_ast::{ }; use ruff_text_size::{Ranged, TextRange}; use rustpython_compiler_core::{PositionEncoding, SourceFile, SourceLocation}; -use std::{borrow::Cow, fmt}; /// Captures all symbols in the current scope, and has a list of sub-scopes in this scope. #[derive(Clone)] @@ -215,8 +215,8 @@ impl SymbolTableError { type SymbolTableResult = Result; -impl std::fmt::Debug for SymbolTable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for SymbolTable { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!( f, "SymbolTable({:?} symbols, {:?} sub scopes)", @@ -261,8 +261,8 @@ fn drop_class_free(symbol_table: &mut SymbolTable) { type SymbolMap = IndexMap; mod stack { + use core::ptr::NonNull; use std::panic; - use std::ptr::NonNull; pub struct StackStack { v: Vec>, } @@ -325,7 +325,7 @@ struct SymbolTableAnalyzer { impl SymbolTableAnalyzer { fn analyze_symbol_table(&mut self, symbol_table: &mut SymbolTable) -> SymbolTableResult { - let symbols = std::mem::take(&mut symbol_table.symbols); + let symbols = core::mem::take(&mut symbol_table.symbols); let sub_tables = &mut *symbol_table.sub_tables; let mut info = (symbols, symbol_table.typ); @@ -689,7 +689,7 @@ impl 
SymbolTableBuilder { fn leave_scope(&mut self) { let mut table = self.tables.pop().unwrap(); // Save the collected varnames to the symbol table - table.varnames = std::mem::take(&mut self.current_varnames); + table.varnames = core::mem::take(&mut self.current_varnames); self.tables.last_mut().unwrap().sub_tables.push(table); } diff --git a/crates/codegen/src/unparse.rs b/crates/codegen/src/unparse.rs index 74e35fd5e2a..7b26d229187 100644 --- a/crates/codegen/src/unparse.rs +++ b/crates/codegen/src/unparse.rs @@ -1,3 +1,5 @@ +use alloc::fmt; +use core::fmt::Display as _; use ruff_python_ast::{ self as ruff, Arguments, BoolOp, Comprehension, ConversionFlag, Expr, Identifier, Operator, Parameter, ParameterWithDefault, Parameters, @@ -5,7 +7,6 @@ use ruff_python_ast::{ use ruff_text_size::Ranged; use rustpython_compiler_core::SourceFile; use rustpython_literal::escape::{AsciiEscape, UnicodeEscape}; -use std::fmt::{self, Display as _}; mod precedence { macro_rules! precedence { @@ -51,7 +52,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { } fn p_delim(&mut self, first: &mut bool, s: &str) -> fmt::Result { - self.p_if(!std::mem::take(first), s) + self.p_if(!core::mem::take(first), s) } fn write_fmt(&mut self, f: fmt::Arguments<'_>) -> fmt::Result { @@ -575,7 +576,7 @@ impl<'a, 'b, 'c> Unparser<'a, 'b, 'c> { if conversion != ConversionFlag::None { self.p("!")?; let buf = &[conversion as u8]; - let c = std::str::from_utf8(buf).unwrap(); + let c = core::str::from_utf8(buf).unwrap(); self.p(c)?; } @@ -650,7 +651,7 @@ impl fmt::Display for UnparseExpr<'_> { } fn to_string_fmt(f: impl FnOnce(&mut fmt::Formatter<'_>) -> fmt::Result) -> String { - use std::cell::Cell; + use core::cell::Cell; struct Fmt(Cell>); impl) -> fmt::Result> fmt::Display for Fmt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/crates/common/src/borrow.rs b/crates/common/src/borrow.rs index 610084006e1..d8389479b33 100644 --- a/crates/common/src/borrow.rs +++ 
b/crates/common/src/borrow.rs @@ -2,10 +2,8 @@ use crate::lock::{ MapImmutable, PyImmutableMappedMutexGuard, PyMappedMutexGuard, PyMappedRwLockReadGuard, PyMappedRwLockWriteGuard, PyMutexGuard, PyRwLockReadGuard, PyRwLockWriteGuard, }; -use std::{ - fmt, - ops::{Deref, DerefMut}, -}; +use alloc::fmt; +use core::ops::{Deref, DerefMut}; macro_rules! impl_from { ($lt:lifetime, $gen:ident, $t:ty, $($var:ident($from:ty),)*) => { diff --git a/crates/common/src/boxvec.rs b/crates/common/src/boxvec.rs index 8687ba7f7f5..3260e76ca87 100644 --- a/crates/common/src/boxvec.rs +++ b/crates/common/src/boxvec.rs @@ -2,13 +2,13 @@ //! An unresizable vector backed by a `Box<[T]>` #![allow(clippy::needless_lifetimes)] - -use std::{ +use alloc::{fmt, slice}; +use core::{ borrow::{Borrow, BorrowMut}, - cmp, fmt, + cmp, mem::{self, MaybeUninit}, ops::{Bound, Deref, DerefMut, RangeBounds}, - ptr, slice, + ptr, }; pub struct BoxVec { @@ -555,7 +555,7 @@ impl Extend for BoxVec { }; let mut iter = iter.into_iter(); loop { - if std::ptr::eq(ptr, end_ptr) { + if core::ptr::eq(ptr, end_ptr) { break; } if let Some(elt) = iter.next() { @@ -693,7 +693,7 @@ impl CapacityError { const CAPERROR: &str = "insufficient capacity"; -impl std::error::Error for CapacityError {} +impl core::error::Error for CapacityError {} impl fmt::Display for CapacityError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/crates/common/src/cformat.rs b/crates/common/src/cformat.rs index b553f0b6b10..24332396fdb 100644 --- a/crates/common/src/cformat.rs +++ b/crates/common/src/cformat.rs @@ -1,15 +1,16 @@ //! Implementation of Printf-Style string formatting //! as per the [Python Docs](https://docs.python.org/3/library/stdtypes.html#printf-style-string-formatting). 
+use alloc::fmt; use bitflags::bitflags; +use core::{ + cmp, + iter::{Enumerate, Peekable}, + str::FromStr, +}; use itertools::Itertools; use malachite_bigint::{BigInt, Sign}; use num_traits::Signed; use rustpython_literal::{float, format::Case}; -use std::{ - cmp, fmt, - iter::{Enumerate, Peekable}, - str::FromStr, -}; use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; @@ -785,7 +786,7 @@ impl CFormatStrOrBytes { if !literal.is_empty() { parts.push(( part_index, - CFormatPart::Literal(std::mem::take(&mut literal)), + CFormatPart::Literal(core::mem::take(&mut literal)), )); } let spec = CFormatSpecKeyed::parse(iter).map_err(|err| CFormatError { @@ -816,7 +817,7 @@ impl CFormatStrOrBytes { impl IntoIterator for CFormatStrOrBytes { type Item = (usize, CFormatPart); - type IntoIter = std::vec::IntoIter; + type IntoIter = alloc::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.parts.into_iter() diff --git a/crates/common/src/crt_fd.rs b/crates/common/src/crt_fd.rs index b873ef9c52c..1902a362e32 100644 --- a/crates/common/src/crt_fd.rs +++ b/crates/common/src/crt_fd.rs @@ -1,7 +1,9 @@ //! A module implementing an io type backed by the C runtime's file descriptors, i.e. what's //! returned from libc::open, even on windows. 
-use std::{cmp, ffi, fmt, io}; +use alloc::fmt; +use core::cmp; +use std::{ffi, io}; #[cfg(not(windows))] use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; diff --git a/crates/common/src/encodings.rs b/crates/common/src/encodings.rs index 39ca2661262..d54581eb9ea 100644 --- a/crates/common/src/encodings.rs +++ b/crates/common/src/encodings.rs @@ -1,4 +1,4 @@ -use std::ops::{self, Range}; +use core::ops::{self, Range}; use num_traits::ToPrimitive; @@ -260,7 +260,7 @@ pub mod errors { use crate::str::UnicodeEscapeCodepoint; use super::*; - use std::fmt::Write; + use core::fmt::Write; pub struct Strict; diff --git a/crates/common/src/fileutils.rs b/crates/common/src/fileutils.rs index a12c1cd82e5..9ed5e77afbb 100644 --- a/crates/common/src/fileutils.rs +++ b/crates/common/src/fileutils.rs @@ -9,7 +9,7 @@ pub use windows::{StatStruct, fstat}; #[cfg(not(windows))] pub fn fstat(fd: crate::crt_fd::Borrowed<'_>) -> std::io::Result { - let mut stat = std::mem::MaybeUninit::uninit(); + let mut stat = core::mem::MaybeUninit::uninit(); unsafe { let ret = libc::fstat(fd.as_raw(), stat.as_mut_ptr()); if ret == -1 { @@ -165,7 +165,7 @@ pub mod windows { } fn file_time_to_time_t_nsec(in_ptr: &FILETIME) -> (libc::time_t, libc::c_int) { - let in_val: i64 = unsafe { std::mem::transmute_copy(in_ptr) }; + let in_val: i64 = unsafe { core::mem::transmute_copy(in_ptr) }; let nsec_out = (in_val % 10_000_000) * 100; // FILETIME is in units of 100 nsec. 
let time_out = (in_val / 10_000_000) - SECS_BETWEEN_EPOCHS; (time_out, nsec_out as _) @@ -204,7 +204,7 @@ pub mod windows { let st_nlink = info.nNumberOfLinks as i32; let st_ino = if let Some(id_info) = id_info { - let file_id: [u64; 2] = unsafe { std::mem::transmute_copy(&id_info.FileId) }; + let file_id: [u64; 2] = unsafe { core::mem::transmute_copy(&id_info.FileId) }; file_id } else { let ino = ((info.nFileIndexHigh as u64) << 32) + info.nFileIndexLow as u64; @@ -313,7 +313,7 @@ pub mod windows { unsafe { GetProcAddress(module, name.as_bytes_with_nul().as_ptr()) } { Some(unsafe { - std::mem::transmute::< + core::mem::transmute::< unsafe extern "system" fn() -> isize, unsafe extern "system" fn( *const u16, @@ -441,7 +441,7 @@ pub mod windows { // Open a file using std::fs::File and convert to FILE* // Automatically handles path encoding and EINTR retries pub fn fopen(path: &std::path::Path, mode: &str) -> std::io::Result<*mut libc::FILE> { - use std::ffi::CString; + use alloc::ffi::CString; use std::fs::File; // Currently only supports read mode diff --git a/crates/common/src/format.rs b/crates/common/src/format.rs index 447ae575f48..1afee519aef 100644 --- a/crates/common/src/format.rs +++ b/crates/common/src/format.rs @@ -1,4 +1,6 @@ // spell-checker:ignore ddfe +use core::ops::Deref; +use core::{cmp, str::FromStr}; use itertools::{Itertools, PeekingNext}; use malachite_base::num::basic::floats::PrimitiveFloat; use malachite_bigint::{BigInt, Sign}; @@ -7,8 +9,6 @@ use num_traits::FromPrimitive; use num_traits::{Signed, cast::ToPrimitive}; use rustpython_literal::float; use rustpython_literal::format::Case; -use std::ops::Deref; -use std::{cmp, str::FromStr}; use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; @@ -598,7 +598,7 @@ impl FormatSpec { (Some(_), _) => Err(FormatSpecError::NotAllowed("Sign")), (_, true) => Err(FormatSpecError::NotAllowed("Alternate form (#)")), (_, _) => match num.to_u32() { - Some(n) if n <= 0x10ffff => 
Ok(std::char::from_u32(n).unwrap().to_string()), + Some(n) if n <= 0x10ffff => Ok(core::char::from_u32(n).unwrap().to_string()), Some(_) | None => Err(FormatSpecError::CodeNotInRange), }, }, diff --git a/crates/common/src/hash.rs b/crates/common/src/hash.rs index dcf424f7ba9..40c428d89e3 100644 --- a/crates/common/src/hash.rs +++ b/crates/common/src/hash.rs @@ -1,7 +1,7 @@ +use core::hash::{BuildHasher, Hash, Hasher}; use malachite_bigint::BigInt; use num_traits::ToPrimitive; use siphasher::sip::SipHasher24; -use std::hash::{BuildHasher, Hash, Hasher}; pub type PyHash = i64; pub type PyUHash = u64; @@ -19,9 +19,9 @@ pub const INF: PyHash = 314_159; pub const NAN: PyHash = 0; pub const IMAG: PyHash = MULTIPLIER; pub const ALGO: &str = "siphash24"; -pub const HASH_BITS: usize = std::mem::size_of::() * 8; +pub const HASH_BITS: usize = core::mem::size_of::() * 8; // SipHasher24 takes 2 u64s as a seed -pub const SEED_BITS: usize = std::mem::size_of::() * 2 * 8; +pub const SEED_BITS: usize = core::mem::size_of::() * 2 * 8; // pub const CUTOFF: usize = 7; @@ -134,7 +134,7 @@ pub fn hash_bigint(value: &BigInt) -> PyHash { Some(i) => mod_int(i), None => (value % MODULUS).to_i64().unwrap_or_else(|| unsafe { // SAFETY: MODULUS < i64::MAX, so value % MODULUS is guaranteed to be in the range of i64 - std::hint::unreachable_unchecked() + core::hint::unreachable_unchecked() }), }; fix_sentinel(ret) diff --git a/crates/common/src/int.rs b/crates/common/src/int.rs index ed09cc01a0a..57696e21fe7 100644 --- a/crates/common/src/int.rs +++ b/crates/common/src/int.rs @@ -7,18 +7,18 @@ pub fn true_div(numerator: &BigInt, denominator: &BigInt) -> f64 { let rational = Rational::from_integers_ref(numerator.into(), denominator.into()); match rational.rounding_into(RoundingMode::Nearest) { // returned value is $t::MAX but still less than the original - (val, std::cmp::Ordering::Less) if val == f64::MAX => f64::INFINITY, + (val, core::cmp::Ordering::Less) if val == f64::MAX => f64::INFINITY, 
// returned value is $t::MIN but still greater than the original - (val, std::cmp::Ordering::Greater) if val == f64::MIN => f64::NEG_INFINITY, + (val, core::cmp::Ordering::Greater) if val == f64::MIN => f64::NEG_INFINITY, (val, _) => val, } } pub fn float_to_ratio(value: f64) -> Option<(BigInt, BigInt)> { - let sign = match std::cmp::PartialOrd::partial_cmp(&value, &0.0)? { - std::cmp::Ordering::Less => Sign::Minus, - std::cmp::Ordering::Equal => return Some((BigInt::zero(), BigInt::one())), - std::cmp::Ordering::Greater => Sign::Plus, + let sign = match core::cmp::PartialOrd::partial_cmp(&value, &0.0)? { + core::cmp::Ordering::Less => Sign::Minus, + core::cmp::Ordering::Equal => return Some((BigInt::zero(), BigInt::one())), + core::cmp::Ordering::Greater => Sign::Plus, }; Rational::try_from(value).ok().map(|x| { let (numer, denom) = x.into_numerator_and_denominator(); diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index c99ba0286a4..0181562d043 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -2,6 +2,8 @@ #![cfg_attr(all(target_os = "wasi", target_env = "p2"), feature(wasip2))] +extern crate alloc; + #[macro_use] mod macros; pub use macros::*; diff --git a/crates/common/src/linked_list.rs b/crates/common/src/linked_list.rs index 8afc1478e6b..fb2b1260346 100644 --- a/crates/common/src/linked_list.rs +++ b/crates/common/src/linked_list.rs @@ -253,7 +253,7 @@ impl LinkedList { // === rustpython additions === pub fn iter(&self) -> impl Iterator { - std::iter::successors(self.head, |node| unsafe { + core::iter::successors(self.head, |node| unsafe { L::pointers(*node).as_ref().get_next() }) .map(|ptr| unsafe { ptr.as_ref() }) diff --git a/crates/common/src/lock/cell_lock.rs b/crates/common/src/lock/cell_lock.rs index 25a5cfedba1..73d722a8fdb 100644 --- a/crates/common/src/lock/cell_lock.rs +++ b/crates/common/src/lock/cell_lock.rs @@ -1,9 +1,9 @@ // spell-checker:ignore upgradably sharedly +use core::{cell::Cell, num::NonZero}; 
use lock_api::{ GetThreadId, RawMutex, RawRwLock, RawRwLockDowngrade, RawRwLockRecursive, RawRwLockUpgrade, RawRwLockUpgradeDowngrade, }; -use std::{cell::Cell, num::NonZero}; pub struct RawCellMutex { locked: Cell, diff --git a/crates/common/src/lock/immutable_mutex.rs b/crates/common/src/lock/immutable_mutex.rs index 81c5c93be71..2013cf1c60d 100644 --- a/crates/common/src/lock/immutable_mutex.rs +++ b/crates/common/src/lock/immutable_mutex.rs @@ -1,7 +1,8 @@ #![allow(clippy::needless_lifetimes)] +use alloc::fmt; +use core::{marker::PhantomData, ops::Deref}; use lock_api::{MutexGuard, RawMutex}; -use std::{fmt, marker::PhantomData, ops::Deref}; /// A mutex guard that has an exclusive lock, but only an immutable reference; useful if you /// need to map a mutex guard with a function that returns an `&T`. Construct using the @@ -22,7 +23,7 @@ impl<'a, R: RawMutex, T: ?Sized> MapImmutable<'a, R, T> for MutexGuard<'a, R, T> { let raw = unsafe { MutexGuard::mutex(&s).raw() }; let data = f(&s) as *const U; - std::mem::forget(s); + core::mem::forget(s); ImmutableMappedMutexGuard { raw, data, @@ -38,7 +39,7 @@ impl<'a, R: RawMutex, T: ?Sized> ImmutableMappedMutexGuard<'a, R, T> { { let raw = s.raw; let data = f(&s) as *const U; - std::mem::forget(s); + core::mem::forget(s); ImmutableMappedMutexGuard { raw, data, diff --git a/crates/common/src/lock/thread_mutex.rs b/crates/common/src/lock/thread_mutex.rs index 2949a3c6c14..67ffc89245d 100644 --- a/crates/common/src/lock/thread_mutex.rs +++ b/crates/common/src/lock/thread_mutex.rs @@ -1,14 +1,14 @@ #![allow(clippy::needless_lifetimes)] -use lock_api::{GetThreadId, GuardNoSend, RawMutex}; -use std::{ +use alloc::fmt; +use core::{ cell::UnsafeCell, - fmt, marker::PhantomData, ops::{Deref, DerefMut}, ptr::NonNull, sync::atomic::{AtomicUsize, Ordering}, }; +use lock_api::{GetThreadId, GuardNoSend, RawMutex}; // based off ReentrantMutex from lock_api @@ -174,7 +174,7 @@ impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> 
ThreadMutexGuard<'a, R, G, T> { ) -> MappedThreadMutexGuard<'a, R, G, U> { let data = f(&mut s).into(); let mu = &s.mu.raw; - std::mem::forget(s); + core::mem::forget(s); MappedThreadMutexGuard { mu, data, @@ -188,7 +188,7 @@ impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> ThreadMutexGuard<'a, R, G, T> { if let Some(data) = f(&mut s) { let data = data.into(); let mu = &s.mu.raw; - std::mem::forget(s); + core::mem::forget(s); Ok(MappedThreadMutexGuard { mu, data, @@ -241,7 +241,7 @@ impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> MappedThreadMutexGuard<'a, R, G ) -> MappedThreadMutexGuard<'a, R, G, U> { let data = f(&mut s).into(); let mu = s.mu; - std::mem::forget(s); + core::mem::forget(s); MappedThreadMutexGuard { mu, data, @@ -255,7 +255,7 @@ impl<'a, R: RawMutex, G: GetThreadId, T: ?Sized> MappedThreadMutexGuard<'a, R, G if let Some(data) = f(&mut s) { let data = data.into(); let mu = s.mu; - std::mem::forget(s); + core::mem::forget(s); Ok(MappedThreadMutexGuard { mu, data, diff --git a/crates/common/src/os.rs b/crates/common/src/os.rs index e77a81fd94f..3e09a29210a 100644 --- a/crates/common/src/os.rs +++ b/crates/common/src/os.rs @@ -1,7 +1,8 @@ // spell-checker:disable // TODO: we can move more os-specific bindings/interfaces from stdlib::{os, posix, nt} to here -use std::{io, process::ExitCode, str::Utf8Error}; +use core::str::Utf8Error; +use std::{io, process::ExitCode}; /// Convert exit code to std::process::ExitCode /// diff --git a/crates/common/src/rc.rs b/crates/common/src/rc.rs index 40c7cf97a8d..9e4cca228fd 100644 --- a/crates/common/src/rc.rs +++ b/crates/common/src/rc.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "threading"))] -use std::rc::Rc; +use alloc::rc::Rc; #[cfg(feature = "threading")] -use std::sync::Arc; +use alloc::sync::Arc; // type aliases instead of new-types because you can't do `fn method(self: PyRc)` with a // newtype; requires the arbitrary_self_types unstable feature diff --git a/crates/common/src/str.rs b/crates/common/src/str.rs 
index 2d867130edd..155012ed21f 100644 --- a/crates/common/src/str.rs +++ b/crates/common/src/str.rs @@ -4,8 +4,8 @@ use crate::format::CharLen; use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; use ascii::{AsciiChar, AsciiStr, AsciiString}; use core::fmt; +use core::ops::{Bound, RangeBounds}; use core::sync::atomic::Ordering::Relaxed; -use std::ops::{Bound, RangeBounds}; #[cfg(not(target_arch = "wasm32"))] #[allow(non_camel_case_types)] @@ -22,7 +22,7 @@ pub enum StrKind { Wtf8, } -impl std::ops::BitOr for StrKind { +impl core::ops::BitOr for StrKind { type Output = Self; fn bitor(self, other: Self) -> Self { @@ -128,7 +128,7 @@ impl From for StrLen { } impl fmt::Debug for StrLen { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let len = self.0.load(Relaxed); if len == usize::MAX { f.write_str("") @@ -262,7 +262,7 @@ impl StrData { pub fn as_str(&self) -> Option<&str> { self.kind .is_utf8() - .then(|| unsafe { std::str::from_utf8_unchecked(self.data.as_bytes()) }) + .then(|| unsafe { core::str::from_utf8_unchecked(self.data.as_bytes()) }) } pub fn as_ascii(&self) -> Option<&AsciiStr> { @@ -282,7 +282,7 @@ impl StrData { PyKindStr::Ascii(unsafe { AsciiStr::from_ascii_unchecked(self.data.as_bytes()) }) } StrKind::Utf8 => { - PyKindStr::Utf8(unsafe { std::str::from_utf8_unchecked(self.data.as_bytes()) }) + PyKindStr::Utf8(unsafe { core::str::from_utf8_unchecked(self.data.as_bytes()) }) } StrKind::Wtf8 => PyKindStr::Wtf8(&self.data), } @@ -327,8 +327,8 @@ impl StrData { } } -impl std::fmt::Display for StrData { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for StrData { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.data.fmt(f) } } @@ -421,7 +421,7 @@ pub fn zfill(bytes: &[u8], width: usize) -> Vec { }; let mut filled = Vec::new(); filled.extend_from_slice(sign); - 
filled.extend(std::iter::repeat_n(b'0', width - bytes.len())); + filled.extend(core::iter::repeat_n(b'0', width - bytes.len())); filled.extend_from_slice(s); filled } @@ -465,7 +465,8 @@ impl fmt::Display for UnicodeEscapeCodepoint { } pub mod levenshtein { - use std::{cell::RefCell, thread_local}; + use core::cell::RefCell; + use std::thread_local; pub const MOVE_COST: usize = 2; const CASE_COST: usize = 1; @@ -524,9 +525,9 @@ pub mod levenshtein { } if b_end < a_end { - std::mem::swap(&mut a_bytes, &mut b_bytes); - std::mem::swap(&mut a_begin, &mut b_begin); - std::mem::swap(&mut a_end, &mut b_end); + core::mem::swap(&mut a_bytes, &mut b_bytes); + core::mem::swap(&mut a_begin, &mut b_begin); + core::mem::swap(&mut a_end, &mut b_end); } if (b_end - a_end) * MOVE_COST > max_cost { diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index 8df5d9caf6f..5569fa2012b 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -5,12 +5,13 @@ use crate::{ marshal::MarshalError, {OneIndexed, SourceLocation}, }; +use alloc::{collections::BTreeSet, fmt}; use bitflags::bitflags; +use core::{hash, marker::PhantomData, mem, num::NonZeroU8, ops::Deref}; use itertools::Itertools; use malachite_bigint::BigInt; use num_complex::Complex64; use rustpython_wtf8::{Wtf8, Wtf8Buf}; -use std::{collections::BTreeSet, fmt, hash, marker::PhantomData, mem, num::NonZeroU8, ops::Deref}; /// Oparg values for [`Instruction::ConvertValue`]. 
/// @@ -506,7 +507,7 @@ impl Eq for Arg {} impl fmt::Debug for Arg { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Arg<{}>", std::any::type_name::()) + write!(f, "Arg<{}>", core::any::type_name::()) } } @@ -880,7 +881,7 @@ impl From for u8 { #[inline] fn from(ins: Instruction) -> Self { // SAFETY: there's no padding bits - unsafe { std::mem::transmute::(ins) } + unsafe { core::mem::transmute::(ins) } } } @@ -890,7 +891,7 @@ impl TryFrom for Instruction { #[inline] fn try_from(value: u8) -> Result { if value <= u8::from(LAST_INSTRUCTION) { - Ok(unsafe { std::mem::transmute::(value) }) + Ok(unsafe { core::mem::transmute::(value) }) } else { Err(MarshalError::InvalidBytecode) } @@ -1027,7 +1028,7 @@ impl PartialEq for ConstantData { (Boolean { value: a }, Boolean { value: b }) => a == b, (Str { value: a }, Str { value: b }) => a == b, (Bytes { value: a }, Bytes { value: b }) => a == b, - (Code { code: a }, Code { code: b }) => std::ptr::eq(a.as_ref(), b.as_ref()), + (Code { code: a }, Code { code: b }) => core::ptr::eq(a.as_ref(), b.as_ref()), (Tuple { elements: a }, Tuple { elements: b }) => a == b, (None, None) => true, (Ellipsis, Ellipsis) => true, @@ -1053,7 +1054,7 @@ impl hash::Hash for ConstantData { Boolean { value } => value.hash(state), Str { value } => value.hash(state), Bytes { value } => value.hash(state), - Code { code } => std::ptr::hash(code.as_ref(), state), + Code { code } => core::ptr::hash(code.as_ref(), state), Tuple { elements } => elements.hash(state), None => {} Ellipsis => {} diff --git a/crates/compiler-core/src/lib.rs b/crates/compiler-core/src/lib.rs index 08cdc0ec21f..11246f6f44c 100644 --- a/crates/compiler-core/src/lib.rs +++ b/crates/compiler-core/src/lib.rs @@ -1,6 +1,8 @@ #![doc(html_logo_url = "https://raw.githubusercontent.com/RustPython/RustPython/main/logo.png")] #![doc(html_root_url = "https://docs.rs/rustpython-compiler-core/")] +extern crate alloc; + pub mod bytecode; pub mod frozen; pub mod marshal; 
diff --git a/crates/compiler-core/src/marshal.rs b/crates/compiler-core/src/marshal.rs index 39e48071678..b30894ea065 100644 --- a/crates/compiler-core/src/marshal.rs +++ b/crates/compiler-core/src/marshal.rs @@ -1,8 +1,8 @@ use crate::{OneIndexed, SourceLocation, bytecode::*}; +use core::convert::Infallible; use malachite_bigint::{BigInt, Sign}; use num_complex::Complex64; use rustpython_wtf8::Wtf8; -use std::convert::Infallible; pub const FORMAT_VERSION: u32 = 4; @@ -20,8 +20,8 @@ pub enum MarshalError { BadType, } -impl std::fmt::Display for MarshalError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for MarshalError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Self::Eof => f.write_str("unexpected end of data"), Self::InvalidBytecode => f.write_str("invalid bytecode"), @@ -32,15 +32,15 @@ impl std::fmt::Display for MarshalError { } } -impl From for MarshalError { - fn from(_: std::str::Utf8Error) -> Self { +impl From for MarshalError { + fn from(_: core::str::Utf8Error) -> Self { Self::InvalidUtf8 } } -impl std::error::Error for MarshalError {} +impl core::error::Error for MarshalError {} -type Result = std::result::Result; +type Result = core::result::Result; #[repr(u8)] enum Type { @@ -119,7 +119,7 @@ pub trait Read { } fn read_str(&mut self, len: u32) -> Result<&str> { - Ok(std::str::from_utf8(self.read_slice(len)?)?) + Ok(core::str::from_utf8(self.read_slice(len)?)?) } fn read_wtf8(&mut self, len: u32) -> Result<&Wtf8> { @@ -147,7 +147,7 @@ pub(crate) trait ReadBorrowed<'a>: Read { fn read_slice_borrow(&mut self, n: u32) -> Result<&'a [u8]>; fn read_str_borrow(&mut self, len: u32) -> Result<&'a str> { - Ok(std::str::from_utf8(self.read_slice_borrow(len)?)?) + Ok(core::str::from_utf8(self.read_slice_borrow(len)?)?) 
} } diff --git a/crates/compiler-core/src/mode.rs b/crates/compiler-core/src/mode.rs index 35e9e77f590..f2b19d677be 100644 --- a/crates/compiler-core/src/mode.rs +++ b/crates/compiler-core/src/mode.rs @@ -7,7 +7,7 @@ pub enum Mode { BlockExpr, } -impl std::str::FromStr for Mode { +impl core::str::FromStr for Mode { type Err = ModeParseError; // To support `builtins.compile()` `mode` argument @@ -25,8 +25,8 @@ impl std::str::FromStr for Mode { #[derive(Debug)] pub struct ModeParseError; -impl std::fmt::Display for ModeParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for ModeParseError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, r#"mode must be "exec", "eval", or "single""#) } } diff --git a/crates/compiler/src/lib.rs b/crates/compiler/src/lib.rs index 84e64f3c27f..7fa695c0c71 100644 --- a/crates/compiler/src/lib.rs +++ b/crates/compiler/src/lib.rs @@ -28,8 +28,8 @@ pub struct ParseError { pub source_path: String, } -impl std::fmt::Display for ParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl ::core::fmt::Display for ParseError { + fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { self.error.fmt(f) } } diff --git a/crates/derive-impl/src/compile_bytecode.rs b/crates/derive-impl/src/compile_bytecode.rs index cdcc89b9984..23c90690dad 100644 --- a/crates/derive-impl/src/compile_bytecode.rs +++ b/crates/derive-impl/src/compile_bytecode.rs @@ -58,11 +58,11 @@ pub trait Compiler { source: &str, mode: Mode, module_name: String, - ) -> Result>; + ) -> Result>; } impl CompilationSource { - fn compile_string D>( + fn compile_string D>( &self, source: &str, mode: Mode, diff --git a/crates/derive-impl/src/from_args.rs b/crates/derive-impl/src/from_args.rs index 4633c9b3aac..667f887e81c 100644 --- a/crates/derive-impl/src/from_args.rs +++ b/crates/derive-impl/src/from_args.rs @@ -18,7 +18,7 @@ enum ParameterKind { 
impl TryFrom<&Ident> for ParameterKind { type Error = (); - fn try_from(ident: &Ident) -> std::result::Result { + fn try_from(ident: &Ident) -> core::result::Result { Ok(match ident.to_string().as_str() { "positional" => Self::PositionalOnly, "any" => Self::PositionalOrKeyword, @@ -105,12 +105,12 @@ impl ArgAttribute { impl TryFrom<&Field> for ArgAttribute { type Error = syn::Error; - fn try_from(field: &Field) -> std::result::Result { + fn try_from(field: &Field) -> core::result::Result { let mut pyarg_attrs = field .attrs .iter() .filter_map(Self::from_attribute) - .collect::, _>>()?; + .collect::, _>>()?; if pyarg_attrs.len() >= 2 { bail_span!(field, "Multiple pyarg attributes on field") @@ -234,7 +234,7 @@ pub fn impl_from_args(input: DeriveInput) -> Result { fn from_args( vm: &::rustpython_vm::VirtualMachine, args: &mut ::rustpython_vm::function::FuncArgs - ) -> ::std::result::Result { + ) -> ::core::result::Result { Ok(Self { #fields }) } } diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index 55f9c769940..06bbc06cfb2 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -4,11 +4,11 @@ use crate::util::{ ItemMeta, ItemMetaInner, ItemNursery, SimpleItemMeta, format_doc, pyclass_ident_and_attrs, pyexception_ident_and_attrs, text_signature, }; +use core::str::FromStr; use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree}; use quote::{ToTokens, quote, quote_spanned}; use rustpython_doc::DB; use std::collections::{HashMap, HashSet}; -use std::str::FromStr; use syn::{Attribute, Ident, Item, Result, parse_quote, spanned::Spanned}; use syn_ext::ext::*; use syn_ext::types::*; @@ -25,8 +25,8 @@ enum AttrName { Member, } -impl std::fmt::Display for AttrName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for AttrName { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let s = match self { Self::Method => "pymethod", 
Self::ClassMethod => "pyclassmethod", @@ -44,7 +44,7 @@ impl std::fmt::Display for AttrName { impl FromStr for AttrName { type Err = String; - fn from_str(s: &str) -> std::result::Result { + fn from_str(s: &str) -> core::result::Result { Ok(match s { "pymethod" => Self::Method, "pyclassmethod" => Self::ClassMethod, @@ -1488,7 +1488,7 @@ impl ItemMeta for SlotItemMeta { fn from_nested(item_ident: Ident, meta_ident: Ident, mut nested: I) -> Result where - I: std::iter::Iterator, + I: core::iter::Iterator, { let meta_map = if let Some(nested_meta) = nested.next() { match nested_meta { diff --git a/crates/derive-impl/src/pymodule.rs b/crates/derive-impl/src/pymodule.rs index 2d5ff7cb0c2..3689ac97fd8 100644 --- a/crates/derive-impl/src/pymodule.rs +++ b/crates/derive-impl/src/pymodule.rs @@ -5,10 +5,11 @@ use crate::util::{ ErrorVec, ItemMeta, ItemNursery, ModuleItemMeta, SimpleItemMeta, format_doc, iter_use_idents, pyclass_ident_and_attrs, text_signature, }; +use core::str::FromStr; use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; use quote::{ToTokens, quote, quote_spanned}; use rustpython_doc::DB; -use std::{collections::HashSet, str::FromStr}; +use std::collections::HashSet; use syn::{Attribute, Ident, Item, Result, parse_quote, spanned::Spanned}; use syn_ext::ext::*; use syn_ext::types::PunctuatedNestedMeta; @@ -22,8 +23,8 @@ enum AttrName { StructSequence, } -impl std::fmt::Display for AttrName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for AttrName { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let s = match self { Self::Function => "pyfunction", Self::Attr => "pyattr", @@ -38,7 +39,7 @@ impl std::fmt::Display for AttrName { impl FromStr for AttrName { type Err = String; - fn from_str(s: &str) -> std::result::Result { + fn from_str(s: &str) -> core::result::Result { Ok(match s { "pyfunction" => Self::Function, "pyattr" => Self::Attr, diff --git 
a/crates/derive-impl/src/pytraverse.rs b/crates/derive-impl/src/pytraverse.rs index c5c4bbd2704..c4ec3823298 100644 --- a/crates/derive-impl/src/pytraverse.rs +++ b/crates/derive-impl/src/pytraverse.rs @@ -37,7 +37,7 @@ fn field_to_traverse_code(field: &Field) -> Result { .attrs .iter() .filter_map(pytraverse_arg) - .collect::, _>>()?; + .collect::, _>>()?; let do_trace = if pytraverse_attrs.len() > 1 { bail_span!( field, diff --git a/crates/derive-impl/src/util.rs b/crates/derive-impl/src/util.rs index 379adc65b57..6be1fcdf7ad 100644 --- a/crates/derive-impl/src/util.rs +++ b/crates/derive-impl/src/util.rs @@ -97,7 +97,7 @@ pub(crate) struct ContentItemInner { } pub(crate) trait ContentItem { - type AttrName: std::str::FromStr + std::fmt::Display; + type AttrName: core::str::FromStr + core::fmt::Display; fn inner(&self) -> &ContentItemInner; fn index(&self) -> usize { @@ -125,7 +125,7 @@ impl ItemMetaInner { allowed_names: &[&'static str], ) -> Result where - I: std::iter::Iterator, + I: core::iter::Iterator, { let (meta_map, lits) = nested.into_unique_map_and_lits(|path| { if let Some(ident) = path.get_ident() { @@ -243,7 +243,7 @@ impl ItemMetaInner { pub fn _optional_list( &self, key: &str, - ) -> Result>> { + ) -> Result>> { let value = if let Some((_, meta)) = self.meta_map.get(key) { let Meta::List(MetaList { path: _, nested, .. 
@@ -269,7 +269,7 @@ pub(crate) trait ItemMeta: Sized { fn from_nested(item_ident: Ident, meta_ident: Ident, nested: I) -> Result where - I: std::iter::Iterator, + I: core::iter::Iterator, { Ok(Self::from_inner(ItemMetaInner::from_nested( item_ident, @@ -529,7 +529,7 @@ impl ExceptionItemMeta { } } -impl std::ops::Deref for ExceptionItemMeta { +impl core::ops::Deref for ExceptionItemMeta { type Target = ClassItemMeta; fn deref(&self) -> &Self::Target { &self.0 diff --git a/crates/derive/src/lib.rs b/crates/derive/src/lib.rs index 655ad3b4c9e..5a3ff84c63a 100644 --- a/crates/derive/src/lib.rs +++ b/crates/derive/src/lib.rs @@ -274,7 +274,7 @@ impl derive_impl::Compiler for Compiler { source: &str, mode: rustpython_compiler::Mode, module_name: String, - ) -> Result> { + ) -> Result> { use rustpython_compiler::{CompileOpts, compile}; Ok(compile(source, mode, &module_name, CompileOpts::default())?) } diff --git a/crates/literal/src/escape.rs b/crates/literal/src/escape.rs index 6bdd94e9860..72ceaf60d5b 100644 --- a/crates/literal/src/escape.rs +++ b/crates/literal/src/escape.rs @@ -55,9 +55,9 @@ pub unsafe trait Escape { /// # Safety /// /// This string must only contain printable characters. 
- unsafe fn write_source(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result; - fn write_body_slow(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result; - fn write_body(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + unsafe fn write_source(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result; + fn write_body_slow(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result; + fn write_body(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { if self.changed() { self.write_body_slow(formatter) } else { @@ -117,7 +117,7 @@ impl<'a> UnicodeEscape<'a> { pub struct StrRepr<'r, 'a>(&'r UnicodeEscape<'a>); impl StrRepr<'_, '_> { - pub fn write(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + pub fn write(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { let quote = self.0.layout().quote.to_char(); formatter.write_char(quote)?; self.0.write_body(formatter)?; @@ -131,8 +131,8 @@ impl StrRepr<'_, '_> { } } -impl std::fmt::Display for StrRepr<'_, '_> { - fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for StrRepr<'_, '_> { + fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.write(formatter) } } @@ -217,7 +217,7 @@ impl UnicodeEscape<'_> { ch: CodePoint, quote: Quote, formatter: &mut impl std::fmt::Write, - ) -> std::fmt::Result { + ) -> core::fmt::Result { let Some(ch) = ch.to_char() else { return write!(formatter, "\\u{:04x}", ch.to_u32()); }; @@ -260,7 +260,7 @@ unsafe impl Escape for UnicodeEscape<'_> { &self.layout } - unsafe fn write_source(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + unsafe fn write_source(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { formatter.write_str(unsafe { // SAFETY: this function must be called only when source is printable characters (i.e. 
no surrogates) std::str::from_utf8_unchecked(self.source.as_bytes()) @@ -268,7 +268,7 @@ unsafe impl Escape for UnicodeEscape<'_> { } #[cold] - fn write_body_slow(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + fn write_body_slow(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { for ch in self.source.code_points() { Self::write_char(ch, self.layout().quote, formatter)?; } @@ -378,7 +378,7 @@ impl AsciiEscape<'_> { } } - fn write_char(ch: u8, quote: Quote, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + fn write_char(ch: u8, quote: Quote, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { match ch { b'\t' => formatter.write_str("\\t"), b'\n' => formatter.write_str("\\n"), @@ -404,7 +404,7 @@ unsafe impl Escape for AsciiEscape<'_> { &self.layout } - unsafe fn write_source(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + unsafe fn write_source(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { formatter.write_str(unsafe { // SAFETY: this function must be called only when source is printable ascii characters std::str::from_utf8_unchecked(self.source) @@ -412,7 +412,7 @@ unsafe impl Escape for AsciiEscape<'_> { } #[cold] - fn write_body_slow(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + fn write_body_slow(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { for ch in self.source { Self::write_char(*ch, self.layout().quote, formatter)?; } @@ -423,7 +423,7 @@ unsafe impl Escape for AsciiEscape<'_> { pub struct BytesRepr<'r, 'a>(&'r AsciiEscape<'a>); impl BytesRepr<'_, '_> { - pub fn write(&self, formatter: &mut impl std::fmt::Write) -> std::fmt::Result { + pub fn write(&self, formatter: &mut impl std::fmt::Write) -> core::fmt::Result { let quote = self.0.layout().quote.to_char(); formatter.write_char('b')?; formatter.write_char(quote)?; @@ -438,8 +438,8 @@ impl BytesRepr<'_, '_> { } } -impl std::fmt::Display for BytesRepr<'_, '_> { - fn 
fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for BytesRepr<'_, '_> { + fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.write(formatter) } } diff --git a/crates/literal/src/float.rs b/crates/literal/src/float.rs index e2bc54a8f1b..4d0d65cbb34 100644 --- a/crates/literal/src/float.rs +++ b/crates/literal/src/float.rs @@ -55,7 +55,7 @@ pub fn format_fixed(precision: usize, magnitude: f64, case: Case, alternate_form match magnitude { magnitude if magnitude.is_finite() => { let point = decimal_point_or_empty(precision, alternate_form); - let precision = std::cmp::min(precision, u16::MAX as usize); + let precision = core::cmp::min(precision, u16::MAX as usize); format!("{magnitude:.precision$}{point}") } magnitude if magnitude.is_nan() => format_nan(case), diff --git a/crates/sre_engine/src/engine.rs b/crates/sre_engine/src/engine.rs index 9cc2e4788a5..f1f25a2d920 100644 --- a/crates/sre_engine/src/engine.rs +++ b/crates/sre_engine/src/engine.rs @@ -6,8 +6,8 @@ use crate::string::{ }; use super::{MAXREPEAT, SreAtCode, SreCatCode, SreInfo, SreOpcode, StrDrive, StringCursor}; +use core::{convert::TryFrom, ptr::null}; use optional::Optioned; -use std::{convert::TryFrom, ptr::null}; #[derive(Debug, Clone, Copy)] pub struct Request<'a, S> { @@ -27,8 +27,8 @@ impl<'a, S: StrDrive> Request<'a, S> { pattern_codes: &'a [u32], match_all: bool, ) -> Self { - let end = std::cmp::min(end, string.count()); - let start = std::cmp::min(start, end); + let end = core::cmp::min(end, string.count()); + let start = core::cmp::min(start, end); Self { string, @@ -1332,7 +1332,7 @@ fn _count( ctx: &mut MatchContext, max_count: usize, ) -> usize { - let max_count = std::cmp::min(max_count, ctx.remaining_chars(req)); + let max_count = core::cmp::min(max_count, ctx.remaining_chars(req)); let end = ctx.cursor.position + max_count; let opcode = SreOpcode::try_from(ctx.peek_code(req, 0)).unwrap(); diff --git 
a/crates/sre_engine/src/string.rs b/crates/sre_engine/src/string.rs index 0d3325b6a1d..489819bfb3e 100644 --- a/crates/sre_engine/src/string.rs +++ b/crates/sre_engine/src/string.rs @@ -9,7 +9,7 @@ pub struct StringCursor { impl Default for StringCursor { fn default() -> Self { Self { - ptr: std::ptr::null(), + ptr: core::ptr::null(), position: 0, } } diff --git a/crates/stdlib/src/array.rs b/crates/stdlib/src/array.rs index b51bc02d3fb..b7a6fbd8b4f 100644 --- a/crates/stdlib/src/array.rs +++ b/crates/stdlib/src/array.rs @@ -68,11 +68,12 @@ mod array { }, }, }; + use alloc::fmt; + use core::cmp::Ordering; use itertools::Itertools; use num_traits::ToPrimitive; use rustpython_common::wtf8::{CodePoint, Wtf8, Wtf8Buf}; - use std::{cmp::Ordering, fmt, os::raw}; - + use std::os::raw; macro_rules! def_array_enum { ($(($n:ident, $t:ty, $c:literal, $scode:literal)),*$(,)?) => { #[derive(Debug, Clone)] @@ -104,14 +105,14 @@ mod array { const fn itemsize_of_typecode(c: char) -> Option { match c { - $($c => Some(std::mem::size_of::<$t>()),)* + $($c => Some(core::mem::size_of::<$t>()),)* _ => None, } } const fn itemsize(&self) -> usize { match self { - $(ArrayContentType::$n(_) => std::mem::size_of::<$t>(),)* + $(ArrayContentType::$n(_) => core::mem::size_of::<$t>(),)* } } @@ -201,10 +202,10 @@ mod array { if v.is_empty() { // safe because every configuration of bytes for the types we // support are valid - let b = std::mem::ManuallyDrop::new(b); + let b = core::mem::ManuallyDrop::new(b); let ptr = b.as_ptr() as *mut $t; - let len = b.len() / std::mem::size_of::<$t>(); - let capacity = b.capacity() / std::mem::size_of::<$t>(); + let len = b.len() / core::mem::size_of::<$t>(); + let capacity = b.capacity() / core::mem::size_of::<$t>(); *v = unsafe { Vec::from_raw_parts(ptr, len, capacity) }; } else { self.frombytes(&b); @@ -220,8 +221,8 @@ mod array { // support are valid if b.len() > 0 { let ptr = b.as_ptr() as *const $t; - let ptr_len = b.len() / std::mem::size_of::<$t>(); - 
let slice = unsafe { std::slice::from_raw_parts(ptr, ptr_len) }; + let ptr_len = b.len() / core::mem::size_of::<$t>(); + let slice = unsafe { core::slice::from_raw_parts(ptr, ptr_len) }; v.extend_from_slice(slice); } })* @@ -249,8 +250,8 @@ mod array { $(ArrayContentType::$n(v) => { // safe because we're just reading memory as bytes let ptr = v.as_ptr() as *const u8; - let ptr_len = v.len() * std::mem::size_of::<$t>(); - unsafe { std::slice::from_raw_parts(ptr, ptr_len) } + let ptr_len = v.len() * core::mem::size_of::<$t>(); + unsafe { core::slice::from_raw_parts(ptr, ptr_len) } })* } } @@ -260,8 +261,8 @@ mod array { $(ArrayContentType::$n(v) => { // safe because we're just reading memory as bytes let ptr = v.as_ptr() as *mut u8; - let ptr_len = v.len() * std::mem::size_of::<$t>(); - unsafe { std::slice::from_raw_parts_mut(ptr, ptr_len) } + let ptr_len = v.len() * core::mem::size_of::<$t>(); + unsafe { core::slice::from_raw_parts_mut(ptr, ptr_len) } })* } } @@ -785,18 +786,18 @@ mod array { if item_size == 2 { // safe because every configuration of bytes for the types we support are valid let utf16 = unsafe { - std::slice::from_raw_parts( + core::slice::from_raw_parts( bytes.as_ptr() as *const u16, - bytes.len() / std::mem::size_of::(), + bytes.len() / core::mem::size_of::(), ) }; Ok(Wtf8Buf::from_wide(utf16)) } else { // safe because every configuration of bytes for the types we support are valid let chars = unsafe { - std::slice::from_raw_parts( + core::slice::from_raw_parts( bytes.as_ptr() as *const u32, - bytes.len() / std::mem::size_of::(), + bytes.len() / core::mem::size_of::(), ) }; chars @@ -1516,7 +1517,7 @@ mod array { impl MachineFormatCode { fn from_typecode(code: char) -> Option { - use std::mem::size_of; + use core::mem::size_of; let signed = code.is_ascii_uppercase(); let big_endian = cfg!(target_endian = "big"); let int_size = match code { @@ -1590,7 +1591,7 @@ mod array { macro_rules! 
chunk_to_obj { ($BYTE:ident, $TY:ty, $BIG_ENDIAN:ident) => {{ - let b = <[u8; ::std::mem::size_of::<$TY>()]>::try_from($BYTE).unwrap(); + let b = <[u8; ::core::mem::size_of::<$TY>()]>::try_from($BYTE).unwrap(); if $BIG_ENDIAN { <$TY>::from_be_bytes(b) } else { @@ -1601,7 +1602,7 @@ mod array { chunk_to_obj!($BYTE, $TY, $BIG_ENDIAN).to_pyobject($VM) }; ($VM:ident, $BYTE:ident, $SIGNED_TY:ty, $UNSIGNED_TY:ty, $SIGNED:ident, $BIG_ENDIAN:ident) => {{ - let b = <[u8; ::std::mem::size_of::<$SIGNED_TY>()]>::try_from($BYTE).unwrap(); + let b = <[u8; ::core::mem::size_of::<$SIGNED_TY>()]>::try_from($BYTE).unwrap(); match ($SIGNED, $BIG_ENDIAN) { (false, false) => <$UNSIGNED_TY>::from_le_bytes(b).to_pyobject($VM), (false, true) => <$UNSIGNED_TY>::from_be_bytes(b).to_pyobject($VM), diff --git a/crates/stdlib/src/binascii.rs b/crates/stdlib/src/binascii.rs index a2316d3c204..671d1d9e253 100644 --- a/crates/stdlib/src/binascii.rs +++ b/crates/stdlib/src/binascii.rs @@ -359,7 +359,7 @@ mod decl { } _ => unsafe { // quad_pos is only assigned in this match statement to constants - std::hint::unreachable_unchecked() + core::hint::unreachable_unchecked() }, } } diff --git a/crates/stdlib/src/bz2.rs b/crates/stdlib/src/bz2.rs index a2a40953cff..93142e92a68 100644 --- a/crates/stdlib/src/bz2.rs +++ b/crates/stdlib/src/bz2.rs @@ -15,9 +15,10 @@ mod _bz2 { object::PyResult, types::Constructor, }; + use alloc::fmt; use bzip2::{Decompress, Status, write::BzEncoder}; use rustpython_vm::convert::ToPyException; - use std::{fmt, io::Write}; + use std::io::Write; const BUFSIZ: usize = 8192; diff --git a/crates/stdlib/src/cmath.rs b/crates/stdlib/src/cmath.rs index e5d1d55a578..7f975e41719 100644 --- a/crates/stdlib/src/cmath.rs +++ b/crates/stdlib/src/cmath.rs @@ -11,7 +11,7 @@ mod cmath { // Constants #[pyattr] - use std::f64::consts::{E as e, PI as pi, TAU as tau}; + use core::f64::consts::{E as e, PI as pi, TAU as tau}; #[pyattr(name = "inf")] const INF: f64 = f64::INFINITY; #[pyattr(name = 
"nan")] @@ -93,7 +93,7 @@ mod cmath { z.log( base.into_option() .map(|base| base.re) - .unwrap_or(std::f64::consts::E), + .unwrap_or(core::f64::consts::E), ) } diff --git a/crates/stdlib/src/compression.rs b/crates/stdlib/src/compression.rs index 7f4e3432eab..a857b4e53de 100644 --- a/crates/stdlib/src/compression.rs +++ b/crates/stdlib/src/compression.rs @@ -107,7 +107,7 @@ impl<'a> Chunker<'a> { pub fn advance(&mut self, consumed: usize) { self.data1 = &self.data1[consumed..]; if self.data1.is_empty() { - self.data1 = std::mem::take(&mut self.data2); + self.data1 = core::mem::take(&mut self.data2); } } } @@ -140,7 +140,7 @@ pub fn _decompress_chunks( let chunk = data.chunk(); let flush = calc_flush(chunk.len() == data.len()); loop { - let additional = std::cmp::min(bufsize, max_length - buf.capacity()); + let additional = core::cmp::min(bufsize, max_length - buf.capacity()); if additional == 0 { return Ok((buf, false)); } diff --git a/crates/stdlib/src/contextvars.rs b/crates/stdlib/src/contextvars.rs index f88ce398c1c..731f5d11e0b 100644 --- a/crates/stdlib/src/contextvars.rs +++ b/crates/stdlib/src/contextvars.rs @@ -1,6 +1,6 @@ use crate::vm::{PyRef, VirtualMachine, builtins::PyModule, class::StaticType}; use _contextvars::PyContext; -use std::cell::RefCell; +use core::cell::RefCell; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { let module = _contextvars::make_module(vm); @@ -31,13 +31,13 @@ mod _contextvars { protocol::{PyMappingMethods, PySequenceMethods}, types::{AsMapping, AsSequence, Constructor, Hashable, Representable}, }; - use crossbeam_utils::atomic::AtomicCell; - use indexmap::IndexMap; - use std::sync::LazyLock; - use std::{ + use core::{ cell::{Cell, RefCell, UnsafeCell}, sync::atomic::Ordering, }; + use crossbeam_utils::atomic::AtomicCell; + use indexmap::IndexMap; + use std::sync::LazyLock; // TODO: Real hamt implementation type Hamt = IndexMap, PyObjectRef, ahash::RandomState>; @@ -90,11 +90,11 @@ mod _contextvars { } } - fn 
borrow_vars(&self) -> impl std::ops::Deref + '_ { + fn borrow_vars(&self) -> impl core::ops::Deref + '_ { self.inner.vars.hamt.borrow() } - fn borrow_vars_mut(&self) -> impl std::ops::DerefMut + '_ { + fn borrow_vars_mut(&self) -> impl core::ops::DerefMut + '_ { self.inner.vars.hamt.borrow_mut() } @@ -293,13 +293,13 @@ mod _contextvars { #[pytraverse(skip)] cached: AtomicCell>, #[pytraverse(skip)] - cached_id: std::sync::atomic::AtomicUsize, // cached_tsid in CPython + cached_id: core::sync::atomic::AtomicUsize, // cached_tsid in CPython #[pytraverse(skip)] hash: UnsafeCell, } - impl std::fmt::Debug for ContextVar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for ContextVar { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("ContextVar").finish() } } @@ -308,7 +308,7 @@ mod _contextvars { impl PartialEq for ContextVar { fn eq(&self, other: &Self) -> bool { - std::ptr::eq(self, other) + core::ptr::eq(self, other) } } impl Eq for ContextVar {} @@ -512,9 +512,9 @@ mod _contextvars { } } - impl std::hash::Hash for ContextVar { + impl core::hash::Hash for ContextVar { #[inline] - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { unsafe { *self.hash.get() }.hash(state) } } diff --git a/crates/stdlib/src/csv.rs b/crates/stdlib/src/csv.rs index a62594a9f1b..4f6cbd76828 100644 --- a/crates/stdlib/src/csv.rs +++ b/crates/stdlib/src/csv.rs @@ -12,12 +12,13 @@ mod _csv { raise_if_stop, types::{Constructor, IterNext, Iterable, SelfIter}, }; + use alloc::fmt; use csv_core::Terminator; use itertools::{self, Itertools}; use parking_lot::Mutex; use rustpython_vm::match_class; + use std::collections::HashMap; use std::sync::LazyLock; - use std::{collections::HashMap, fmt}; #[pyattr] const QUOTE_MINIMAL: i32 = QuoteStyle::Minimal as i32; @@ -1006,7 +1007,7 @@ mod _csv { return Err(new_csv_error(vm, "filed too long to read".to_string())); } prev_end = end; - let s = 
std::str::from_utf8(&buffer[range.clone()]) + let s = core::str::from_utf8(&buffer[range.clone()]) // not sure if this is possible - the input was all strings .map_err(|_e| vm.new_unicode_decode_error("csv not utf8"))?; // Rustpython TODO! @@ -1116,7 +1117,7 @@ mod _csv { loop { handle_res!(writer.terminator(&mut buffer[buffer_offset..])); } - let s = std::str::from_utf8(&buffer[..buffer_offset]) + let s = core::str::from_utf8(&buffer[..buffer_offset]) .map_err(|_| vm.new_unicode_decode_error("csv not utf8"))?; self.write.call((s,), vm) diff --git a/crates/stdlib/src/faulthandler.rs b/crates/stdlib/src/faulthandler.rs index f45c9909c6f..eba5643b866 100644 --- a/crates/stdlib/src/faulthandler.rs +++ b/crates/stdlib/src/faulthandler.rs @@ -7,13 +7,13 @@ mod decl { PyObjectRef, PyResult, VirtualMachine, builtins::PyFloat, frame::Frame, function::OptionalArg, py_io::Write, }; + use alloc::sync::Arc; + use core::sync::atomic::{AtomicBool, AtomicI32, Ordering}; + use core::time::Duration; use parking_lot::{Condvar, Mutex}; #[cfg(any(unix, windows))] use rustpython_common::os::{get_errno, set_errno}; - use std::sync::Arc; - use std::sync::atomic::{AtomicBool, AtomicI32, Ordering}; use std::thread; - use std::time::Duration; /// fault_handler_t #[cfg(unix)] @@ -40,7 +40,7 @@ mod decl { enabled: false, name, // SAFETY: sigaction is a C struct that can be zero-initialized - previous: unsafe { std::mem::zeroed() }, + previous: unsafe { core::mem::zeroed() }, } } } @@ -144,7 +144,8 @@ mod decl { static mut FRAME_SNAPSHOTS: [FrameSnapshot; MAX_SNAPSHOT_FRAMES] = [FrameSnapshot::EMPTY; MAX_SNAPSHOT_FRAMES]; #[cfg(any(unix, windows))] - static SNAPSHOT_COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); + static SNAPSHOT_COUNT: core::sync::atomic::AtomicUsize = + core::sync::atomic::AtomicUsize::new(0); // Signal-safe output functions @@ -240,7 +241,7 @@ mod decl { } let thread_id = current_thread_id(); // Use appropriate width based on platform 
pointer size - dump_hexadecimal(fd, thread_id, std::mem::size_of::() * 2); + dump_hexadecimal(fd, thread_id, core::mem::size_of::() * 2); puts(fd, " (most recent call first):\n"); } @@ -429,7 +430,7 @@ mod decl { } handler.enabled = false; unsafe { - libc::sigaction(handler.signum, &handler.previous, std::ptr::null_mut()); + libc::sigaction(handler.signum, &handler.previous, core::ptr::null_mut()); } } @@ -549,7 +550,7 @@ mod decl { continue; } - let mut action: libc::sigaction = std::mem::zeroed(); + let mut action: libc::sigaction = core::mem::zeroed(); action.sa_sigaction = faulthandler_fatal_error as libc::sighandler_t; // SA_NODEFER flag action.sa_flags = libc::SA_NODEFER; @@ -1051,8 +1052,8 @@ mod decl { #[cfg(not(target_arch = "wasm32"))] unsafe { suppress_crash_report(); - let ptr: *const i32 = std::ptr::null(); - std::ptr::read_volatile(ptr); + let ptr: *const i32 = core::ptr::null(); + core::ptr::read_volatile(ptr); } } @@ -1132,7 +1133,7 @@ mod decl { panic!("Fatal Python error: in new thread"); }); // Wait a bit for the thread to panic - std::thread::sleep(std::time::Duration::from_secs(1)); + std::thread::sleep(core::time::Duration::from_secs(1)); } } @@ -1203,7 +1204,7 @@ mod decl { suppress_crash_report(); unsafe { - RaiseException(args.code, args.flags, 0, std::ptr::null()); + RaiseException(args.code, args.flags, 0, core::ptr::null()); } } } diff --git a/crates/stdlib/src/fcntl.rs b/crates/stdlib/src/fcntl.rs index dc6a0b8171e..822faeeedaa 100644 --- a/crates/stdlib/src/fcntl.rs +++ b/crates/stdlib/src/fcntl.rs @@ -173,7 +173,7 @@ mod fcntl { }; } - let mut l: libc::flock = unsafe { std::mem::zeroed() }; + let mut l: libc::flock = unsafe { core::mem::zeroed() }; l.l_type = if cmd == libc::LOCK_UN { try_into_l_type!(libc::F_UNLCK) } else if (cmd & libc::LOCK_SH) != 0 { diff --git a/crates/stdlib/src/grp.rs b/crates/stdlib/src/grp.rs index 4664d5fc575..9f7e4195509 100644 --- a/crates/stdlib/src/grp.rs +++ b/crates/stdlib/src/grp.rs @@ -10,8 +10,8 @@ 
mod grp { exceptions, types::PyStructSequence, }; + use core::ptr::NonNull; use nix::unistd; - use std::ptr::NonNull; #[pystruct_sequence_data] struct GroupData { @@ -30,7 +30,7 @@ mod grp { impl GroupData { fn from_unistd_group(group: unistd::Group, vm: &VirtualMachine) -> Self { - let cstr_lossy = |s: std::ffi::CString| { + let cstr_lossy = |s: alloc::ffi::CString| { s.into_string() .unwrap_or_else(|e| e.into_cstring().to_string_lossy().into_owned()) }; diff --git a/crates/stdlib/src/hashlib.rs b/crates/stdlib/src/hashlib.rs index e7b03a2ff12..2ef485b0329 100644 --- a/crates/stdlib/src/hashlib.rs +++ b/crates/stdlib/src/hashlib.rs @@ -91,8 +91,8 @@ pub mod _hashlib { pub ctx: PyRwLock, } - impl std::fmt::Debug for PyHasher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for PyHasher { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "HASH {}", self.name) } } @@ -164,8 +164,8 @@ pub mod _hashlib { ctx: PyRwLock, } - impl std::fmt::Debug for PyHasherXof { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for PyHasherXof { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "HASHXOF {}", self.name) } } diff --git a/crates/stdlib/src/json.rs b/crates/stdlib/src/json.rs index eb6ed3a5f64..a3fd7972126 100644 --- a/crates/stdlib/src/json.rs +++ b/crates/stdlib/src/json.rs @@ -12,9 +12,9 @@ mod _json { protocol::PyIterReturn, types::{Callable, Constructor}, }; + use core::str::FromStr; use malachite_bigint::BigInt; use rustpython_common::wtf8::Wtf8Buf; - use std::str::FromStr; #[pyattr(name = "make_scanner")] #[pyclass(name = "Scanner", traverse)] @@ -216,7 +216,7 @@ mod _json { let mut buf = Vec::::with_capacity(s.len() + 2); machinery::write_json_string(s, ascii_only, &mut buf) // SAFETY: writing to a vec can't fail - .unwrap_or_else(|_| unsafe { std::hint::unreachable_unchecked() }); + .unwrap_or_else(|_| unsafe 
{ core::hint::unreachable_unchecked() }); // SAFETY: we only output valid utf8 from write_json_string unsafe { String::from_utf8_unchecked(buf) } } diff --git a/crates/stdlib/src/lib.rs b/crates/stdlib/src/lib.rs index 4b463e09c73..6b7796c8bad 100644 --- a/crates/stdlib/src/lib.rs +++ b/crates/stdlib/src/lib.rs @@ -6,6 +6,7 @@ #[macro_use] extern crate rustpython_derive; +extern crate alloc; pub mod array; mod binascii; @@ -103,7 +104,7 @@ use rustpython_common as common; use rustpython_vm as vm; use crate::vm::{builtins, stdlib::StdlibInitFunc}; -use std::borrow::Cow; +use alloc::borrow::Cow; pub fn get_module_inits() -> impl Iterator, StdlibInitFunc)> { macro_rules! modules { diff --git a/crates/stdlib/src/locale.rs b/crates/stdlib/src/locale.rs index 6cca8b9123b..c65f861d208 100644 --- a/crates/stdlib/src/locale.rs +++ b/crates/stdlib/src/locale.rs @@ -41,16 +41,14 @@ use libc::localeconv; #[pymodule] mod _locale { + use alloc::ffi::CString; + use core::{ffi::CStr, ptr}; use rustpython_vm::{ PyObjectRef, PyResult, VirtualMachine, builtins::{PyDictRef, PyIntRef, PyListRef, PyStrRef, PyTypeRef}, convert::ToPyException, function::OptionalArg, }; - use std::{ - ffi::{CStr, CString}, - ptr, - }; #[cfg(all( unix, diff --git a/crates/stdlib/src/lzma.rs b/crates/stdlib/src/lzma.rs index 855a5eae562..b18ac3ee69a 100644 --- a/crates/stdlib/src/lzma.rs +++ b/crates/stdlib/src/lzma.rs @@ -8,6 +8,7 @@ mod _lzma { CompressFlushKind, CompressState, CompressStatusKind, Compressor, DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor, }; + use alloc::fmt; #[pyattr] use lzma_sys::{ LZMA_CHECK_CRC32 as CHECK_CRC32, LZMA_CHECK_CRC64 as CHECK_CRC64, @@ -38,7 +39,6 @@ mod _lzma { use rustpython_vm::function::ArgBytesLike; use rustpython_vm::types::Constructor; use rustpython_vm::{Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; - use std::fmt; use xz2::stream::{Action, Check, Error, Filters, LzmaOptions, Status, Stream}; #[cfg(windows)] diff --git 
a/crates/stdlib/src/math.rs b/crates/stdlib/src/math.rs index 62b0ef73ad3..6e139530804 100644 --- a/crates/stdlib/src/math.rs +++ b/crates/stdlib/src/math.rs @@ -10,15 +10,15 @@ mod math { function::{ArgIndex, ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs}, identifier, }; + use core::cmp::Ordering; use itertools::Itertools; use malachite_bigint::BigInt; use num_traits::{One, Signed, ToPrimitive, Zero}; use rustpython_common::{float_ops, int::true_div}; - use std::cmp::Ordering; // Constants #[pyattr] - use std::f64::consts::{E as e, PI as pi, TAU as tau}; + use core::f64::consts::{E as e, PI as pi, TAU as tau}; use super::pymath_error_to_exception; #[pyattr(name = "inf")] @@ -136,7 +136,7 @@ mod math { #[pyfunction] fn log(x: PyObjectRef, base: OptionalArg, vm: &VirtualMachine) -> PyResult { - let base = base.map(|b| *b).unwrap_or(std::f64::consts::E); + let base = base.map(|b| *b).unwrap_or(core::f64::consts::E); if base.is_sign_negative() { return Err(vm.new_value_error("math domain error")); } @@ -359,9 +359,9 @@ mod math { .iter() .copied() .map(|x| (x / scale).powi(2)) - .chain(std::iter::once(-norm * norm)) + .chain(core::iter::once(-norm * norm)) // Pairwise summation of floats gives less rounding error than a naive sum. - .tree_reduce(std::ops::Add::add) + .tree_reduce(core::ops::Add::add) .expect("expected at least 1 element"); norm = norm + correction / (2.0 * norm); } @@ -424,12 +424,12 @@ mod math { #[pyfunction] fn degrees(x: ArgIntoFloat) -> f64 { - *x * (180.0 / std::f64::consts::PI) + *x * (180.0 / core::f64::consts::PI) } #[pyfunction] fn radians(x: ArgIntoFloat) -> f64 { - *x * (std::f64::consts::PI / 180.0) + *x * (core::f64::consts::PI / 180.0) } // Hyperbolic functions: @@ -684,7 +684,7 @@ mod math { for j in 0..partials.len() { let mut y: f64 = partials[j]; if x.abs() < y.abs() { - std::mem::swap(&mut x, &mut y); + core::mem::swap(&mut x, &mut y); } // Rounded `x+y` is stored in `hi` with round-off stored in // `lo`. 
Together `hi+lo` are exactly equal to `x+y`. diff --git a/crates/stdlib/src/mmap.rs b/crates/stdlib/src/mmap.rs index 5309917a999..916ce6f5962 100644 --- a/crates/stdlib/src/mmap.rs +++ b/crates/stdlib/src/mmap.rs @@ -21,11 +21,11 @@ mod mmap { sliceable::{SaturatedSlice, SequenceIndex, SequenceIndexOp}, types::{AsBuffer, AsMapping, AsSequence, Constructor, Representable}, }; + use core::ops::{Deref, DerefMut}; use crossbeam_utils::atomic::AtomicCell; use memmap2::{Mmap, MmapMut, MmapOptions}; use num_traits::Signed; use std::io::{self, Write}; - use std::ops::{Deref, DerefMut}; #[cfg(unix)] use nix::{sys::stat::fstat, unistd}; @@ -1056,7 +1056,7 @@ mod mmap { // 3. Replace the old mmap let old_size = self.size.load(); - let copy_size = std::cmp::min(old_size, newsize); + let copy_size = core::cmp::min(old_size, newsize); // Create new anonymous mmap let mut new_mmap_opts = MmapOptions::new(); diff --git a/crates/stdlib/src/opcode.rs b/crates/stdlib/src/opcode.rs index c355b59df91..bd4b9aa750a 100644 --- a/crates/stdlib/src/opcode.rs +++ b/crates/stdlib/src/opcode.rs @@ -8,7 +8,7 @@ mod opcode { bytecode::Instruction, match_class, }; - use std::ops::Deref; + use core::ops::Deref; struct Opcode(Instruction); diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index d352d15a614..38103a9ab05 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -522,10 +522,10 @@ mod _ssl { // Thread-local storage for VirtualMachine pointer during handshake // SNI callback is only called during handshake which is synchronous thread_local! 
{ - static HANDSHAKE_VM: std::cell::Cell> = const { std::cell::Cell::new(None) }; + static HANDSHAKE_VM: core::cell::Cell> = const { core::cell::Cell::new(None) }; // SSL pointer during handshake - needed because connection lock is held during handshake // and callbacks may need to access SSL without acquiring the lock - static HANDSHAKE_SSL_PTR: std::cell::Cell> = const { std::cell::Cell::new(None) }; + static HANDSHAKE_SSL_PTR: core::cell::Cell> = const { core::cell::Cell::new(None) }; } // RAII guard to set/clear thread-local handshake context @@ -1896,7 +1896,7 @@ mod _ssl { ))); return Err(openssl::error::ErrorStack::get()); } - let len = std::cmp::min(pw.len(), buf.len()); + let len = core::cmp::min(pw.len(), buf.len()); buf[..len].copy_from_slice(&pw[..len]); Ok(len) } @@ -2714,7 +2714,7 @@ mod _ssl { // Use thread-local SSL pointer during handshake to avoid deadlock let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); unsafe { - let mut out: *const libc::c_uchar = std::ptr::null(); + let mut out: *const libc::c_uchar = core::ptr::null(); let mut outlen: libc::c_uint = 0; sys::SSL_get0_alpn_selected(ssl_ptr, &mut out, &mut outlen); diff --git a/crates/stdlib/src/overlapped.rs b/crates/stdlib/src/overlapped.rs index d8f14baf35e..1c74ee271b9 100644 --- a/crates/stdlib/src/overlapped.rs +++ b/crates/stdlib/src/overlapped.rs @@ -35,7 +35,7 @@ mod _overlapped { #[pyattr] const INVALID_HANDLE_VALUE: isize = - unsafe { std::mem::transmute(windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE) }; + unsafe { core::mem::transmute(windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE) }; #[pyattr] const NULL: isize = 0; @@ -57,8 +57,8 @@ mod _overlapped { unsafe impl Sync for OverlappedInner {} unsafe impl Send for OverlappedInner {} - impl std::fmt::Debug for Overlapped { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for Overlapped { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let 
zelf = self.inner.lock(); f.debug_struct("Overlapped") // .field("overlapped", &(self.overlapped as *const _ as usize)) @@ -98,8 +98,8 @@ mod _overlapped { address_length: libc::c_int, } - impl std::fmt::Debug for OverlappedReadFrom { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for OverlappedReadFrom { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("OverlappedReadFrom") .field("result", &self.result) .field("allocated_buffer", &self.allocated_buffer) @@ -119,8 +119,8 @@ mod _overlapped { address_length: libc::c_int, } - impl std::fmt::Debug for OverlappedReadFromInto { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for OverlappedReadFromInto { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("OverlappedReadFromInto") .field("result", &self.result) .field("user_buffer", &self.user_buffer) @@ -226,7 +226,7 @@ mod _overlapped { } #[cfg(target_pointer_width = "32")] - let size = std::cmp::min(size, std::isize::MAX as _); + let size = core::cmp::min(size, std::isize::MAX as _); let buf = vec![0u8; std::cmp::max(size, 1) as usize]; let buf = vm.ctx.new_bytes(buf); @@ -272,10 +272,10 @@ mod _overlapped { if event == INVALID_HANDLE_VALUE { event = unsafe { windows_sys::Win32::System::Threading::CreateEventA( - std::ptr::null(), + core::ptr::null(), Foundation::TRUE, Foundation::FALSE, - std::ptr::null(), + core::ptr::null(), ) as isize }; if event == NULL { @@ -378,11 +378,11 @@ mod _overlapped { let name = widestring::WideCString::from_str(&name).unwrap(); name.as_ptr() } - None => std::ptr::null(), + None => core::ptr::null(), }; let event = unsafe { windows_sys::Win32::System::Threading::CreateEventW( - std::ptr::null(), + core::ptr::null(), manual_reset as _, initial_state as _, name, diff --git a/crates/stdlib/src/posixshmem.rs b/crates/stdlib/src/posixshmem.rs index 2957f16792c..53bf372532d 
100644 --- a/crates/stdlib/src/posixshmem.rs +++ b/crates/stdlib/src/posixshmem.rs @@ -4,7 +4,7 @@ pub(crate) use _posixshmem::make_module; #[cfg(all(unix, not(target_os = "redox"), not(target_os = "android")))] #[pymodule] mod _posixshmem { - use std::ffi::CString; + use alloc::ffi::CString; use crate::{ common::os::errno_io_error, diff --git a/crates/stdlib/src/posixsubprocess.rs b/crates/stdlib/src/posixsubprocess.rs index d05b24fd6dd..5dd499abf40 100644 --- a/crates/stdlib/src/posixsubprocess.rs +++ b/crates/stdlib/src/posixsubprocess.rs @@ -13,15 +13,15 @@ use nix::{ unistd::{self, Pid}, }; use std::{ - convert::Infallible as Never, - ffi::{CStr, CString}, io::prelude::*, - marker::PhantomData, - ops::Deref, os::fd::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd}, }; use unistd::{Gid, Uid}; +use alloc::ffi::CString; + +use core::{convert::Infallible as Never, ffi::CStr, marker::PhantomData, ops::Deref}; + pub(crate) use _posixsubprocess::make_module; #[pymodule] @@ -87,7 +87,7 @@ impl<'a, T: AsRef> FromIterator<&'a T> for CharPtrVec<'a> { let vec = iter .into_iter() .map(|x| x.as_ref().as_ptr()) - .chain(std::iter::once(std::ptr::null())) + .chain(core::iter::once(core::ptr::null())) .collect(); Self { vec, diff --git a/crates/stdlib/src/resource.rs b/crates/stdlib/src/resource.rs index 052f45e0cad..e6df75e4b01 100644 --- a/crates/stdlib/src/resource.rs +++ b/crates/stdlib/src/resource.rs @@ -9,7 +9,8 @@ mod resource { convert::{ToPyException, ToPyObject}, types::PyStructSequence, }; - use std::{io, mem}; + use core::mem; + use std::io; cfg_if::cfg_if! 
{ if #[cfg(target_os = "android")] { diff --git a/crates/stdlib/src/scproxy.rs b/crates/stdlib/src/scproxy.rs index 1974e7814ae..40267579029 100644 --- a/crates/stdlib/src/scproxy.rs +++ b/crates/stdlib/src/scproxy.rs @@ -22,7 +22,7 @@ mod _scproxy { fn proxy_dict() -> Option> { // Py_BEGIN_ALLOW_THREADS - let proxy_dict = unsafe { SCDynamicStoreCopyProxies(std::ptr::null()) }; + let proxy_dict = unsafe { SCDynamicStoreCopyProxies(core::ptr::null()) }; // Py_END_ALLOW_THREADS if proxy_dict.is_null() { None diff --git a/crates/stdlib/src/select.rs b/crates/stdlib/src/select.rs index 5639a66d2cc..3c2f5e63c7c 100644 --- a/crates/stdlib/src/select.rs +++ b/crates/stdlib/src/select.rs @@ -4,7 +4,8 @@ use crate::vm::{ PyObject, PyObjectRef, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::PyListRef, builtins::PyModule, }; -use std::{io, mem}; +use core::mem; +use std::io; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[cfg(windows)] @@ -158,7 +159,7 @@ impl FdSet { pub fn new() -> Self { // it's just ints, and all the code that's actually // interacting with it is in C, so it's safe to zero - let mut fdset = std::mem::MaybeUninit::zeroed(); + let mut fdset = core::mem::MaybeUninit::zeroed(); unsafe { platform::FD_ZERO(fdset.as_mut_ptr()) }; Self(fdset) } @@ -191,7 +192,7 @@ pub fn select( ) -> io::Result { let timeout = match timeout { Some(tv) => tv as *mut timeval, - None => std::ptr::null_mut(), + None => core::ptr::null_mut(), }; let ret = unsafe { platform::select( @@ -336,12 +337,10 @@ mod decl { function::OptionalArg, stdlib::io::Fildes, }; + use core::{convert::TryFrom, time::Duration}; use libc::pollfd; use num_traits::{Signed, ToPrimitive}; - use std::{ - convert::TryFrom, - time::{Duration, Instant}, - }; + use std::time::Instant; #[derive(Default)] pub(super) struct TimeoutArg(pub Option); @@ -554,8 +553,8 @@ mod decl { stdlib::io::Fildes, types::Constructor, }; + use core::ops::Deref; use rustix::event::epoll::{self, EventData, 
EventFlags}; - use std::ops::Deref; use std::os::fd::{AsRawFd, IntoRawFd, OwnedFd}; use std::time::Instant; diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index 08b05b56aa8..c2cd676f17e 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -22,15 +22,19 @@ mod _socket { types::{Constructor, DefaultConstructor, Initializer, Representable}, utils::ToCString, }; + use core::{ + mem::MaybeUninit, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + time::Duration, + }; use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; use socket2::Socket; use std::{ ffi, io::{self, Read, Write}, - mem::MaybeUninit, - net::{self, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs}, - time::{Duration, Instant}, + net::{self, Shutdown, ToSocketAddrs}, + time::Instant, }; #[cfg(unix)] @@ -795,7 +799,7 @@ mod _socket { sock: PyRwLock>, } - const _: () = assert!(std::mem::size_of::>() == std::mem::size_of::()); + const _: () = assert!(core::mem::size_of::>() == core::mem::size_of::()); impl Default for PySocket { fn default() -> Self { @@ -1099,7 +1103,7 @@ mod _socket { Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF)) ) => { - std::mem::forget(sock); + core::mem::forget(sock); return Err(e.into()); } _ => {} @@ -1409,7 +1413,7 @@ mod _socket { cmsgs: &[(i32, i32, ArgBytesLike)], vm: &VirtualMachine, ) -> PyResult> { - use std::{mem, ptr}; + use core::{mem, ptr}; if cmsgs.is_empty() { return Ok(vec![]); @@ -1535,7 +1539,7 @@ mod _socket { let buflen = buflen.unwrap_or(0); if buflen == 0 { let mut flag: libc::c_int = 0; - let mut flagsize = std::mem::size_of::() as _; + let mut flagsize = core::mem::size_of::() as _; let ret = unsafe { c::getsockopt( fd as _, @@ -1595,11 +1599,11 @@ mod _socket { level, name, val as *const i32 as *const _, - std::mem::size_of::() as _, + core::mem::size_of::() as _, ) }, (None, OptionalArg::Present(optlen)) => unsafe { - c::setsockopt(fd as _, level, name, std::ptr::null(), optlen as _) + 
c::setsockopt(fd as _, level, name, core::ptr::null(), optlen as _) }, _ => { return Err(vm @@ -1651,7 +1655,7 @@ mod _socket { } impl ToSocketAddrs for Address { - type Iter = std::vec::IntoIter; + type Iter = alloc::vec::IntoIter; fn to_socket_addrs(&self) -> io::Result { (self.host.as_str(), self.port).to_socket_addrs() } @@ -1767,7 +1771,7 @@ mod _socket { } fn cstr_opt_as_ptr(x: &OptionalArg) -> *const libc::c_char { - x.as_ref().map_or_else(std::ptr::null, |s| s.as_ptr()) + x.as_ref().map_or_else(core::ptr::null, |s| s.as_ptr()) } #[pyfunction] @@ -1952,7 +1956,7 @@ mod _socket { }; let host = opts.host.as_ref().map(|s| s.as_str()); - let port = opts.port.as_ref().map(|p| -> std::borrow::Cow<'_, str> { + let port = opts.port.as_ref().map(|p| -> alloc::borrow::Cow<'_, str> { match p { Either::A(s) => s.as_str().into(), Either::B(i) => i.to_string().into(), @@ -2316,7 +2320,7 @@ mod _socket { .state .codec_registry .encode_text(pyname, "idna", None, vm)?; - let name = std::str::from_utf8(name.as_bytes()) + let name = core::str::from_utf8(name.as_bytes()) .map_err(|_| vm.new_runtime_error("idna output is not utf8"))?; let mut res = dns_lookup::getaddrinfo(Some(name), None, Some(hints)) .map_err(|e| convert_socket_error(vm, e, SocketError::GaiError))?; @@ -2472,7 +2476,7 @@ mod _socket { #[pyfunction] fn dup(x: PyObjectRef, vm: &VirtualMachine) -> Result { let sock = get_raw_sock(x, vm)?; - let sock = std::mem::ManuallyDrop::new(sock_from_raw(sock, vm)?); + let sock = core::mem::ManuallyDrop::new(sock_from_raw(sock, vm)?); let newsock = sock.try_clone()?; let fd = into_sock_fileno(newsock); #[cfg(windows)] diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index 103a827e99a..3a82787cd8f 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -844,7 +844,7 @@ mod _sqlite { } impl Debug for Connection { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + fn fmt(&self, f: &mut 
core::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { write!(f, "Sqlite3 Connection") } } @@ -2583,7 +2583,7 @@ mod _sqlite { } impl Debug for Statement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { write!( f, "{} Statement", diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 16449e2d019..b90176a62fa 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -52,14 +52,12 @@ mod _ssl { use super::error::{ PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error, }; - use std::{ - collections::HashMap, - sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, - }, - time::{Duration, SystemTime}, + use alloc::sync::Arc; + use core::{ + sync::atomic::{AtomicUsize, Ordering}, + time::Duration, }; + use std::{collections::HashMap, time::SystemTime}; // Rustls imports use parking_lot::{Mutex as ParkingMutex, RwLock as ParkingRwLock}; @@ -3124,7 +3122,7 @@ mod _ssl { // When server_hostname=None, use an IP address to suppress SNI // no hostname = no SNI extension ServerName::IpAddress( - std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)).into(), + core::net::IpAddr::V4(core::net::Ipv4Addr::new(127, 0, 0, 1)).into(), ) }; @@ -3385,7 +3383,7 @@ mod _ssl { let mut written = 0; while written < data.len() { - let chunk_end = std::cmp::min(written + CHUNK_SIZE, data.len()); + let chunk_end = core::cmp::min(written + CHUNK_SIZE, data.len()); let chunk = &data[written..chunk_end]; // Write chunk to TLS layer @@ -4176,8 +4174,8 @@ mod _ssl { #[pygetset] fn id(&self, vm: &VirtualMachine) -> PyBytesRef { // Return session ID (hash of session data for uniqueness) + use core::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); self.session_data.hash(&mut hasher); @@ -4487,7 +4485,7 @@ mod _ssl { let mut result = 
Vec::new(); - let mut crl_context: *const CRL_CONTEXT = std::ptr::null(); + let mut crl_context: *const CRL_CONTEXT = core::ptr::null(); loop { crl_context = unsafe { CertEnumCRLsInStore(store, crl_context) }; if crl_context.is_null() { @@ -4587,8 +4585,8 @@ mod _ssl { // Implement Hashable trait for PySSLCertificate impl Hashable for PySSLCertificate { fn hash(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + use core::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); zelf.der_bytes.hash(&mut hasher); diff --git a/crates/stdlib/src/ssl/cert.rs b/crates/stdlib/src/ssl/cert.rs index b3cb7d6c14e..cd39972cf41 100644 --- a/crates/stdlib/src/ssl/cert.rs +++ b/crates/stdlib/src/ssl/cert.rs @@ -9,6 +9,7 @@ //! - Building and verifying certificate chains //! - Loading certificates from files, directories, and bytes +use alloc::sync::Arc; use chrono::{DateTime, Utc}; use parking_lot::RwLock as ParkingRwLock; use rustls::{ @@ -19,7 +20,6 @@ use rustls::{ }; use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; use std::collections::HashSet; -use std::sync::Arc; use x509_parser::prelude::*; use super::compat::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT}; @@ -51,8 +51,9 @@ const ALL_SIGNATURE_SCHEMES: &[SignatureScheme] = &[ /// operations, reducing code duplication and ensuring uniform error messages /// across the codebase. 
mod cert_error { + use alloc::sync::Arc; + use core::fmt::{Debug, Display}; use std::io; - use std::sync::Arc; /// Create InvalidData error with formatted message pub fn invalid_data(msg: impl Into) -> io::Error { @@ -67,11 +68,11 @@ mod cert_error { invalid_data(format!("no start line: {context}")) } - pub fn parse_failed(e: impl std::fmt::Display) -> io::Error { + pub fn parse_failed(e: impl Display) -> io::Error { invalid_data(format!("Failed to parse PEM certificate: {e}")) } - pub fn parse_failed_debug(e: impl std::fmt::Debug) -> io::Error { + pub fn parse_failed_debug(e: impl Debug) -> io::Error { invalid_data(format!("Failed to parse PEM certificate: {e:?}")) } @@ -88,7 +89,7 @@ mod cert_error { invalid_data(format!("not enough data: {context}")) } - pub fn parse_failed(e: impl std::fmt::Display) -> io::Error { + pub fn parse_failed(e: impl Display) -> io::Error { invalid_data(format!("Failed to parse DER certificate: {e}")) } } @@ -101,15 +102,15 @@ mod cert_error { invalid_data(format!("No private key found in {context}")) } - pub fn parse_failed(e: impl std::fmt::Display) -> io::Error { + pub fn parse_failed(e: impl Display) -> io::Error { invalid_data(format!("Failed to parse private key: {e}")) } - pub fn parse_encrypted_failed(e: impl std::fmt::Display) -> io::Error { + pub fn parse_encrypted_failed(e: impl Display) -> io::Error { invalid_data(format!("Failed to parse encrypted private key: {e}")) } - pub fn decrypt_failed(e: impl std::fmt::Display) -> io::Error { + pub fn decrypt_failed(e: impl Display) -> io::Error { io::Error::other(format!( "Failed to decrypt private key (wrong password?): {e}", )) @@ -383,7 +384,7 @@ pub fn cert_der_to_dict_helper(vm: &VirtualMachine, cert_der: &[u8]) -> PyResult s.to_string() } else { let value_bytes = attr.attr_value().data; - match std::str::from_utf8(value_bytes) { + match core::str::from_utf8(value_bytes) { Ok(s) => s.to_string(), Err(_) => String::from_utf8_lossy(value_bytes).into_owned(), } @@ -1126,7 
+1127,7 @@ pub(super) fn load_cert_chain_from_file( cert_path: &str, key_path: &str, password: Option<&str>, -) -> Result<(Vec>, PrivateKeyDer<'static>), Box> { +) -> Result<(Vec>, PrivateKeyDer<'static>), Box> { // Load certificate file - preserve io::Error for errno let cert_contents = std::fs::read(cert_path)?; @@ -1727,13 +1728,13 @@ fn verify_ip_address( cert: &X509Certificate<'_>, expected_ip: &rustls::pki_types::IpAddr, ) -> Result<(), rustls::Error> { - use std::net::IpAddr; + use core::net::IpAddr; use x509_parser::extensions::GeneralName; // Convert rustls IpAddr to std::net::IpAddr for comparison let expected_std_ip: IpAddr = match expected_ip { - rustls::pki_types::IpAddr::V4(octets) => IpAddr::V4(std::net::Ipv4Addr::from(*octets)), - rustls::pki_types::IpAddr::V6(octets) => IpAddr::V6(std::net::Ipv6Addr::from(*octets)), + rustls::pki_types::IpAddr::V4(octets) => IpAddr::V4(core::net::Ipv4Addr::from(*octets)), + rustls::pki_types::IpAddr::V6(octets) => IpAddr::V6(core::net::Ipv6Addr::from(*octets)), }; // Check Subject Alternative Names for IP addresses @@ -1745,7 +1746,7 @@ fn verify_ip_address( 4 => { // IPv4 if let Ok(octets) = <[u8; 4]>::try_from(*cert_ip_bytes) { - IpAddr::V4(std::net::Ipv4Addr::from(octets)) + IpAddr::V4(core::net::Ipv4Addr::from(octets)) } else { continue; } @@ -1753,7 +1754,7 @@ fn verify_ip_address( 16 => { // IPv6 if let Ok(octets) = <[u8; 16]>::try_from(*cert_ip_bytes) { - IpAddr::V6(std::net::Ipv6Addr::from(octets)) + IpAddr::V6(core::net::Ipv6Addr::from(octets)) } else { continue; } diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index fa12855e242..2168fcfc91f 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -13,6 +13,7 @@ mod ssl_data; use crate::socket::{SelectKind, timeout_error_msg}; use crate::vm::VirtualMachine; +use alloc::sync::Arc; use parking_lot::RwLock as ParkingRwLock; use rustls::RootCertStore; use rustls::client::ClientConfig; @@ -28,7 +29,7 
@@ use rustpython_vm::convert::IntoPyException; use rustpython_vm::function::ArgBytesLike; use rustpython_vm::{AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject}; use std::io::Read; -use std::sync::{Arc, Once}; +use std::sync::Once; // Import PySSLSocket from parent module use super::_ssl::PySSLSocket; diff --git a/crates/stdlib/src/syslog.rs b/crates/stdlib/src/syslog.rs index d0ed3f60949..b52d1415692 100644 --- a/crates/stdlib/src/syslog.rs +++ b/crates/stdlib/src/syslog.rs @@ -11,7 +11,8 @@ mod syslog { function::{OptionalArg, OptionalOption}, utils::ToCString, }; - use std::{ffi::CStr, os::raw::c_char}; + use core::ffi::CStr; + use std::os::raw::c_char; #[pyattr] use libc::{ @@ -50,7 +51,7 @@ mod syslog { fn as_ptr(&self) -> *const c_char { match self { Self::Explicit(cstr) => cstr.as_ptr(), - Self::Implicit => std::ptr::null(), + Self::Implicit => core::ptr::null(), } } } diff --git a/crates/stdlib/src/tkinter.rs b/crates/stdlib/src/tkinter.rs index 687458b193b..49dcdc5f84f 100644 --- a/crates/stdlib/src/tkinter.rs +++ b/crates/stdlib/src/tkinter.rs @@ -59,8 +59,8 @@ mod _tkinter { value: *mut tk_sys::Tcl_Obj, } - impl std::fmt::Debug for TclObject { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for TclObject { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "TclObject") } } @@ -107,8 +107,8 @@ mod _tkinter { unsafe impl Send for TkApp {} unsafe impl Sync for TkApp {} - impl std::fmt::Debug for TkApp { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for TkApp { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "TkApp") } } diff --git a/crates/stdlib/src/zlib.rs b/crates/stdlib/src/zlib.rs index 9ca94939f78..632543c5c64 100644 --- a/crates/stdlib/src/zlib.rs +++ b/crates/stdlib/src/zlib.rs @@ -39,7 +39,7 @@ mod zlib { #[pyattr(name = "ZLIB_RUNTIME_VERSION")] #[pyattr] const ZLIB_VERSION: 
&str = unsafe { - match std::ffi::CStr::from_ptr(libz_sys::zlibVersion()).to_str() { + match core::ffi::CStr::from_ptr(libz_sys::zlibVersion()).to_str() { Ok(s) => s, Err(_) => unreachable!(), } @@ -322,7 +322,7 @@ mod zlib { }; let inner = &mut *self.inner.lock(); - let data = std::mem::replace(&mut inner.unconsumed_tail, vm.ctx.empty_bytes.clone()); + let data = core::mem::replace(&mut inner.unconsumed_tail, vm.ctx.empty_bytes.clone()); let (ret, _) = Self::decompress_inner(inner, &data, length, None, true, vm)?; diff --git a/crates/venvlauncher/src/main.rs b/crates/venvlauncher/src/main.rs index aaf584dfa87..fe147ce7ff3 100644 --- a/crates/venvlauncher/src/main.rs +++ b/crates/venvlauncher/src/main.rs @@ -22,7 +22,7 @@ fn main() -> ExitCode { } } -fn run() -> Result> { +fn run() -> Result> { // 1. Get own executable path let exe_path = env::current_exe()?; let exe_name = exe_path @@ -72,7 +72,7 @@ fn run() -> Result> { } /// Parse the `home=` value from pyvenv.cfg -fn read_home(cfg_path: &Path) -> Result> { +fn read_home(cfg_path: &Path) -> Result> { let content = fs::read_to_string(cfg_path)?; for line in content.lines() { @@ -95,7 +95,7 @@ fn read_home(cfg_path: &Path) -> Result> { } /// Launch the Python process and wait for it to complete -fn launch_process(exe: &Path, args: &[String]) -> Result> { +fn launch_process(exe: &Path, args: &[String]) -> Result> { use std::process::Command; let status = Command::new(exe).args(args).status()?; diff --git a/crates/vm/src/anystr.rs b/crates/vm/src/anystr.rs index ef6d24c100e..79b62a58abf 100644 --- a/crates/vm/src/anystr.rs +++ b/crates/vm/src/anystr.rs @@ -6,6 +6,8 @@ use crate::{ }; use num_traits::{cast::ToPrimitive, sign::Signed}; +use core::ops::Range; + #[derive(FromArgs)] pub struct SplitArgs { #[pyarg(any, default)] @@ -43,7 +45,7 @@ pub struct StartsEndsWithArgs { } impl StartsEndsWithArgs { - pub fn get_value(self, len: usize) -> (PyObjectRef, Option>) { + pub fn get_value(self, len: usize) -> (PyObjectRef, 
Option>) { let range = if self.start.is_some() || self.end.is_some() { Some(adjust_indices(self.start, self.end, len)) } else { @@ -56,7 +58,7 @@ impl StartsEndsWithArgs { pub fn prepare(self, s: &S, len: usize, substr: F) -> Option<(PyObjectRef, &S)> where S: ?Sized + AnyStr, - F: Fn(&S, std::ops::Range) -> &S, + F: Fn(&S, Range) -> &S, { let (affix, range) = self.get_value(len); let substr = if let Some(range) = range { @@ -83,11 +85,7 @@ fn saturate_to_isize(py_int: PyIntRef) -> isize { } // help get optional string indices -pub fn adjust_indices( - start: Option, - end: Option, - len: usize, -) -> std::ops::Range { +pub fn adjust_indices(start: Option, end: Option, len: usize) -> Range { let mut start = start.map_or(0, saturate_to_isize); let mut end = end.map_or(len as isize, saturate_to_isize); if end > len as isize { @@ -111,7 +109,7 @@ pub trait StringRange { fn is_normal(&self) -> bool; } -impl StringRange for std::ops::Range { +impl StringRange for Range { fn is_normal(&self) -> bool { self.start <= self.end } @@ -144,9 +142,9 @@ pub trait AnyStr { fn to_container(&self) -> Self::Container; fn as_bytes(&self) -> &[u8]; fn elements(&self) -> impl Iterator; - fn get_bytes(&self, range: std::ops::Range) -> &Self; + fn get_bytes(&self, range: Range) -> &Self; // FIXME: get_chars is expensive for str - fn get_chars(&self, range: std::ops::Range) -> &Self; + fn get_chars(&self, range: Range) -> &Self; fn bytes_len(&self) -> usize; // NOTE: str::chars().count() consumes the O(n) time. But pystr::char_len does cache. // So using chars_len directly is too expensive and the below method shouldn't be implemented. 
@@ -254,7 +252,7 @@ pub trait AnyStr { } #[inline] - fn py_find(&self, needle: &Self, range: std::ops::Range, find: F) -> Option + fn py_find(&self, needle: &Self, range: Range, find: F) -> Option where F: Fn(&Self, &Self) -> Option, { @@ -268,7 +266,7 @@ pub trait AnyStr { } #[inline] - fn py_count(&self, needle: &Self, range: std::ops::Range, count: F) -> usize + fn py_count(&self, needle: &Self, range: Range, count: F) -> usize where F: Fn(&Self, &Self) -> usize, { @@ -283,9 +281,9 @@ pub trait AnyStr { let mut u = Self::Container::with_capacity( (left + right) * fillchar.bytes_len() + self.bytes_len(), ); - u.extend(std::iter::repeat_n(fillchar, left)); + u.extend(core::iter::repeat_n(fillchar, left)); u.push_str(self); - u.extend(std::iter::repeat_n(fillchar, right)); + u.extend(core::iter::repeat_n(fillchar, right)); u } @@ -305,7 +303,7 @@ pub trait AnyStr { fn py_join( &self, - mut iter: impl std::iter::Iterator + TryFromObject>>, + mut iter: impl core::iter::Iterator + TryFromObject>>, ) -> PyResult { let mut joined = if let Some(elem) = iter.next() { elem?.as_ref().unwrap().to_container() @@ -328,7 +326,7 @@ pub trait AnyStr { ) -> PyResult<(Self::Container, bool, Self::Container)> where F: Fn() -> S, - S: std::iter::Iterator, + S: core::iter::Iterator, { if sub.is_empty() { return Err(vm.new_value_error("empty separator")); diff --git a/crates/vm/src/buffer.rs b/crates/vm/src/buffer.rs index 5c67f87521d..3d5e48015ea 100644 --- a/crates/vm/src/buffer.rs +++ b/crates/vm/src/buffer.rs @@ -5,11 +5,13 @@ use crate::{ convert::ToPyObject, function::{ArgBytesLike, ArgIntoBool, ArgIntoFloat}, }; +use alloc::fmt; +use core::{iter::Peekable, mem}; use half::f16; use itertools::Itertools; use malachite_bigint::BigInt; use num_traits::{PrimInt, ToPrimitive}; -use std::{fmt, iter::Peekable, mem, os::raw}; +use std::os::raw; type PackFunc = fn(&VirtualMachine, PyObjectRef, &mut [u8]) -> PyResult<()>; type UnpackFunc = fn(&VirtualMachine, &[u8]) -> PyObjectRef; @@ 
-545,7 +547,7 @@ macro_rules! make_pack_prim_int { } #[inline] fn unpack_int(data: &[u8]) -> Self { - let mut x = [0; std::mem::size_of::<$T>()]; + let mut x = [0; core::mem::size_of::<$T>()]; x.copy_from_slice(data); E::convert(<$T>::from_ne_bytes(x)) } @@ -681,7 +683,7 @@ fn pack_pascal(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResul } let b = ArgBytesLike::try_from_object(vm, arg)?; b.with_ref(|data| { - let string_length = std::cmp::min(std::cmp::min(data.len(), 255), buf.len() - 1); + let string_length = core::cmp::min(core::cmp::min(data.len(), 255), buf.len() - 1); buf[0] = string_length as u8; write_string(&mut buf[1..], data); }); @@ -689,7 +691,7 @@ fn pack_pascal(vm: &VirtualMachine, arg: PyObjectRef, buf: &mut [u8]) -> PyResul } fn write_string(buf: &mut [u8], data: &[u8]) { - let len_from_data = std::cmp::min(data.len(), buf.len()); + let len_from_data = core::cmp::min(data.len(), buf.len()); buf[..len_from_data].copy_from_slice(&data[..len_from_data]); for byte in &mut buf[len_from_data..] 
{ *byte = 0 @@ -708,7 +710,7 @@ fn unpack_pascal(vm: &VirtualMachine, data: &[u8]) -> PyObjectRef { return vm.ctx.new_bytes(vec![]).into(); } }; - let len = std::cmp::min(len as usize, data.len()); + let len = core::cmp::min(len as usize, data.len()); vm.ctx.new_bytes(data[..len].to_vec()).into() } diff --git a/crates/vm/src/builtins/bool.rs b/crates/vm/src/builtins/bool.rs index cfd1f136d14..3dabdbae717 100644 --- a/crates/vm/src/builtins/bool.rs +++ b/crates/vm/src/builtins/bool.rs @@ -8,9 +8,9 @@ use crate::{ protocol::PyNumberMethods, types::{AsNumber, Constructor, Representable}, }; +use core::fmt::{Debug, Formatter}; use malachite_bigint::Sign; use num_traits::Zero; -use std::fmt::{Debug, Formatter}; impl ToPyObject for bool { fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { @@ -90,7 +90,7 @@ impl PyObjectRef { pub struct PyBool(pub PyInt); impl Debug for PyBool { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { let value = !self.0.as_bigint().is_zero(); write!(f, "PyBool({})", value) } diff --git a/crates/vm/src/builtins/builtin_func.rs b/crates/vm/src/builtins/builtin_func.rs index 2b569375b28..422f922df94 100644 --- a/crates/vm/src/builtins/builtin_func.rs +++ b/crates/vm/src/builtins/builtin_func.rs @@ -7,7 +7,7 @@ use crate::{ function::{FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags, PyNativeFn}, types::{Callable, Comparable, PyComparisonOp, Representable}, }; -use std::fmt; +use alloc::fmt; // PyCFunctionObject in CPython #[pyclass(name = "builtin_function_or_method", module = false)] @@ -212,7 +212,7 @@ impl Comparable for PyNativeMethod { (None, None) => true, _ => false, }; - let eq = eq && std::ptr::eq(zelf.func.value, other.func.value); + let eq = eq && core::ptr::eq(zelf.func.value, other.func.value); Ok(eq.into()) } else { Ok(PyComparisonValue::NotImplemented) diff --git a/crates/vm/src/builtins/bytearray.rs b/crates/vm/src/builtins/bytearray.rs index 
c5861befb73..212e4604ec9 100644 --- a/crates/vm/src/builtins/bytearray.rs +++ b/crates/vm/src/builtins/bytearray.rs @@ -37,7 +37,7 @@ use crate::{ }, }; use bstr::ByteSlice; -use std::mem::size_of; +use core::mem::size_of; #[pyclass(module = false, name = "bytearray", unhashable = true)] #[derive(Debug, Default)] @@ -687,7 +687,7 @@ impl Initializer for PyByteArray { fn init(zelf: PyRef, options: Self::Args, vm: &VirtualMachine) -> PyResult<()> { // First unpack bytearray and *then* get a lock to set it. let mut inner = options.get_bytearray_inner(vm)?; - std::mem::swap(&mut *zelf.inner_mut(), &mut inner); + core::mem::swap(&mut *zelf.inner_mut(), &mut inner); Ok(()) } } diff --git a/crates/vm/src/builtins/bytes.rs b/crates/vm/src/builtins/bytes.rs index 0c67cd7bf24..b3feac8ac97 100644 --- a/crates/vm/src/builtins/bytes.rs +++ b/crates/vm/src/builtins/bytes.rs @@ -29,8 +29,8 @@ use crate::{ }, }; use bstr::ByteSlice; +use core::{mem::size_of, ops::Deref}; use std::sync::LazyLock; -use std::{mem::size_of, ops::Deref}; #[pyclass(module = false, name = "bytes")] #[derive(Clone, Debug)] diff --git a/crates/vm/src/builtins/code.rs b/crates/vm/src/builtins/code.rs index e46cc711bb3..b897ef9d311 100644 --- a/crates/vm/src/builtins/code.rs +++ b/crates/vm/src/builtins/code.rs @@ -11,10 +11,11 @@ use crate::{ function::OptionalArg, types::{Constructor, Representable}, }; +use alloc::fmt; +use core::{borrow::Borrow, ops::Deref}; use malachite_bigint::BigInt; use num_traits::Zero; use rustpython_compiler_core::{OneIndexed, bytecode::CodeUnits, bytecode::PyCodeLocationInfoKind}; -use std::{borrow::Borrow, fmt, ops::Deref}; /// State for iterating through code address ranges struct PyCodeAddressRange<'a> { @@ -601,7 +602,7 @@ impl PyCode { pub fn co_code(&self, vm: &VirtualMachine) -> crate::builtins::PyBytesRef { // SAFETY: CodeUnit is #[repr(C)] with size 2, so we can safely transmute to bytes let bytes = unsafe { - std::slice::from_raw_parts( + core::slice::from_raw_parts( 
self.code.instructions.as_ptr() as *const u8, self.code.instructions.len() * 2, ) diff --git a/crates/vm/src/builtins/complex.rs b/crates/vm/src/builtins/complex.rs index ba74d5e0367..78729b2f5c0 100644 --- a/crates/vm/src/builtins/complex.rs +++ b/crates/vm/src/builtins/complex.rs @@ -10,10 +10,10 @@ use crate::{ stdlib::warnings, types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }; +use core::num::Wrapping; use num_complex::Complex64; use num_traits::Zero; use rustpython_common::hash; -use std::num::Wrapping; /// Create a complex number from a real part and an optional imaginary part. /// diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index 802b81f6d79..aa9da6e2d44 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -59,8 +59,8 @@ impl PyPayload for PyMethodDescriptor { } } -impl std::fmt::Debug for PyMethodDescriptor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyMethodDescriptor { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "method descriptor for '{}'", self.common.name) } } @@ -218,8 +218,8 @@ impl PyMemberDef { } } -impl std::fmt::Debug for PyMemberDef { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyMemberDef { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyMemberDef") .field("name", &self.name) .field("kind", &self.kind) @@ -445,8 +445,8 @@ pub enum SlotFunc { NumTernaryRight(PyNumberTernaryFunc), // __rpow__ (swapped first two args) } -impl std::fmt::Debug for SlotFunc { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for SlotFunc { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { SlotFunc::Init(_) => write!(f, "SlotFunc::Init(...)"), SlotFunc::Hash(_) => write!(f, 
"SlotFunc::Hash(...)"), diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index 567e18d6419..358685fcdc4 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -23,8 +23,8 @@ use crate::{ }, vm::VirtualMachine, }; +use alloc::fmt; use rustpython_common::lock::PyMutex; -use std::fmt; use std::sync::LazyLock; pub type DictContentType = dict_inner::Dict; @@ -219,7 +219,7 @@ impl PyDict { #[pymethod] fn __sizeof__(&self) -> usize { - std::mem::size_of::() + self.entries.sizeof() + core::mem::size_of::() + self.entries.sizeof() } #[pymethod] @@ -759,7 +759,7 @@ impl ExactSizeIterator for DictIter<'_> { #[pyclass] trait DictView: PyPayload + PyClassDef + Iterable + Representable { - type ReverseIter: PyPayload + std::fmt::Debug; + type ReverseIter: PyPayload + core::fmt::Debug; fn dict(&self) -> &Py; fn item(vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef) -> PyObjectRef; diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index c29e45ddcf6..95d70afcbc7 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -149,7 +149,7 @@ impl PyFunction { None }; - let arg_pos = |range: std::ops::Range<_>, name: &str| { + let arg_pos = |range: core::ops::Range<_>, name: &str| { code.varnames .iter() .enumerate() @@ -255,7 +255,7 @@ impl PyFunction { } if let Some(defaults) = defaults { - let n = std::cmp::min(nargs, n_expected_args); + let n = core::cmp::min(nargs, n_expected_args); let i = n.saturating_sub(n_required); // We have sufficient defaults, so iterate over the corresponding names and use diff --git a/crates/vm/src/builtins/genericalias.rs b/crates/vm/src/builtins/genericalias.rs index 8a7288980fa..5596aca9da2 100644 --- a/crates/vm/src/builtins/genericalias.rs +++ b/crates/vm/src/builtins/genericalias.rs @@ -16,7 +16,7 @@ use crate::{ PyComparisonOp, Representable, }, }; -use std::fmt; +use alloc::fmt; // attr_exceptions static 
ATTR_EXCEPTIONS: [&str; 12] = [ diff --git a/crates/vm/src/builtins/getset.rs b/crates/vm/src/builtins/getset.rs index 3fa4667a997..a3f0605a473 100644 --- a/crates/vm/src/builtins/getset.rs +++ b/crates/vm/src/builtins/getset.rs @@ -19,8 +19,8 @@ pub struct PyGetSet { // doc: Option, } -impl std::fmt::Debug for PyGetSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyGetSet { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!( f, "PyGetSet {{ name: {}, getter: {}, setter: {} }}", @@ -158,7 +158,7 @@ impl Representable for PyGetSet { fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { let class = unsafe { zelf.class.borrow_static() }; // Special case for object type - if std::ptr::eq(class, vm.ctx.types.object_type) { + if core::ptr::eq(class, vm.ctx.types.object_type) { Ok(format!("", zelf.name)) } else { Ok(format!( diff --git a/crates/vm/src/builtins/int.rs b/crates/vm/src/builtins/int.rs index 37b41e085ad..182333cea51 100644 --- a/crates/vm/src/builtins/int.rs +++ b/crates/vm/src/builtins/int.rs @@ -17,11 +17,11 @@ use crate::{ protocol::{PyNumberMethods, handle_bytes_to_int_err}, types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }; +use alloc::fmt; +use core::ops::{Neg, Not}; use malachite_bigint::{BigInt, Sign}; use num_integer::Integer; use num_traits::{One, Pow, PrimInt, Signed, ToPrimitive, Zero}; -use std::fmt; -use std::ops::{Neg, Not}; #[pyclass(module = false, name = "int")] #[derive(Debug)] @@ -289,7 +289,7 @@ impl PyInt { I::try_from(self.as_bigint()).map_err(|_| { vm.new_overflow_error(format!( "Python int too large to convert to Rust {}", - std::any::type_name::() + core::any::type_name::() )) }) } @@ -444,7 +444,7 @@ impl PyInt { #[pymethod] fn __sizeof__(&self) -> usize { - std::mem::size_of::() + (((self.value.bits() + 7) & !7) / 8) as usize + core::mem::size_of::() + (((self.value.bits() + 7) & !7) / 8) as usize } 
#[pymethod] diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index 12cab27a750..46145b339cf 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -20,7 +20,8 @@ use crate::{ utils::collection_repr, vm::VirtualMachine, }; -use std::{fmt, ops::DerefMut}; +use alloc::fmt; +use core::ops::DerefMut; #[pyclass(module = false, name = "list", unhashable = true, traverse)] #[derive(Default)] @@ -172,7 +173,7 @@ impl PyList { #[pymethod] fn clear(&self) { - let _removed = std::mem::take(self.borrow_vec_mut().deref_mut()); + let _removed = core::mem::take(self.borrow_vec_mut().deref_mut()); } #[pymethod] @@ -188,8 +189,8 @@ impl PyList { #[pymethod] fn __sizeof__(&self) -> usize { - std::mem::size_of::() - + self.elements.read().capacity() * std::mem::size_of::() + core::mem::size_of::() + + self.elements.read().capacity() * core::mem::size_of::() } #[pymethod] @@ -324,9 +325,9 @@ impl PyList { // replace list contents with [] for duration of sort. // this prevents keyfunc from messing with the list and makes it easy to // check if it tries to append elements to it. 
- let mut elements = std::mem::take(self.borrow_vec_mut().deref_mut()); + let mut elements = core::mem::take(self.borrow_vec_mut().deref_mut()); let res = do_sort(vm, &mut elements, options.key, options.reverse); - std::mem::swap(self.borrow_vec_mut().deref_mut(), &mut elements); + core::mem::swap(self.borrow_vec_mut().deref_mut(), &mut elements); res?; if !elements.is_empty() { @@ -375,7 +376,7 @@ impl MutObjectSequenceOp for PyList { inner.get(index).map(|r| r.as_ref()) } - fn do_lock(&self) -> impl std::ops::Deref { + fn do_lock(&self) -> impl core::ops::Deref { self.borrow_vec() } } @@ -397,7 +398,7 @@ impl Initializer for PyList { } else { vec![] }; - std::mem::swap(zelf.borrow_vec_mut().deref_mut(), &mut elements); + core::mem::swap(zelf.borrow_vec_mut().deref_mut(), &mut elements); Ok(()) } } diff --git a/crates/vm/src/builtins/map.rs b/crates/vm/src/builtins/map.rs index f5cee945ece..f83030824f1 100644 --- a/crates/vm/src/builtins/map.rs +++ b/crates/vm/src/builtins/map.rs @@ -42,7 +42,7 @@ impl PyMap { fn __length_hint__(&self, vm: &VirtualMachine) -> PyResult { self.iterators.iter().try_fold(0, |prev, cur| { let cur = cur.as_ref().to_owned().length_hint(0, vm)?; - let max = std::cmp::max(prev, cur); + let max = core::cmp::max(prev, cur); Ok(max) }) } diff --git a/crates/vm/src/builtins/memory.rs b/crates/vm/src/builtins/memory.rs index 4e895f92b7e..ff5df031c42 100644 --- a/crates/vm/src/builtins/memory.rs +++ b/crates/vm/src/builtins/memory.rs @@ -26,11 +26,11 @@ use crate::{ PyComparisonOp, Representable, SelfIter, }, }; +use core::{cmp::Ordering, fmt::Debug, mem::ManuallyDrop, ops::Range}; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; use rustpython_common::lock::PyMutex; use std::sync::LazyLock; -use std::{cmp::Ordering, fmt::Debug, mem::ManuallyDrop, ops::Range}; #[derive(FromArgs)] pub struct PyMemoryViewNewArgs { diff --git a/crates/vm/src/builtins/module.rs b/crates/vm/src/builtins/module.rs index f8e42b28e0b..faa6e4813fd 
100644 --- a/crates/vm/src/builtins/module.rs +++ b/crates/vm/src/builtins/module.rs @@ -32,8 +32,8 @@ pub struct PyModuleSlots { pub exec: Option, } -impl std::fmt::Debug for PyModuleSlots { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyModuleSlots { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyModuleSlots") .field("create", &self.create.is_some()) .field("exec", &self.exec.is_some()) diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index a61fb1e2971..e73f4b79ae0 100644 --- a/crates/vm/src/builtins/object.rs +++ b/crates/vm/src/builtins/object.rs @@ -218,7 +218,7 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) // basicsize += std::mem::size_of::(); // } if let Some(ref slot_names) = slot_names { - basicsize += std::mem::size_of::() * slot_names.__len__(); + basicsize += core::mem::size_of::() * slot_names.__len__(); } if obj.class().slots.basicsize > basicsize { return Err( diff --git a/crates/vm/src/builtins/property.rs b/crates/vm/src/builtins/property.rs index 7ea36d39768..3a86867176a 100644 --- a/crates/vm/src/builtins/property.rs +++ b/crates/vm/src/builtins/property.rs @@ -10,7 +10,7 @@ use crate::{ function::{FuncArgs, PySetterValue}, types::{Constructor, GetDescriptor, Initializer}, }; -use std::sync::atomic::{AtomicBool, Ordering}; +use core::sync::atomic::{AtomicBool, Ordering}; #[pyclass(module = false, name = "property", traverse)] #[derive(Debug)] @@ -21,7 +21,7 @@ pub struct PyProperty { doc: PyRwLock>, name: PyRwLock>, #[pytraverse(skip)] - getter_doc: std::sync::atomic::AtomicBool, + getter_doc: core::sync::atomic::AtomicBool, } impl PyPayload for PyProperty { diff --git a/crates/vm/src/builtins/range.rs b/crates/vm/src/builtins/range.rs index 9f79f8efb2d..ab84c977ccd 100644 --- a/crates/vm/src/builtins/range.rs +++ b/crates/vm/src/builtins/range.rs @@ -14,11 +14,11 @@ use crate::{ 
Representable, SelfIter, }, }; +use core::cmp::max; use crossbeam_utils::atomic::AtomicCell; use malachite_bigint::{BigInt, Sign}; use num_integer::Integer; use num_traits::{One, Signed, ToPrimitive, Zero}; -use std::cmp::max; use std::sync::LazyLock; // Search flag passed to iter_search diff --git a/crates/vm/src/builtins/set.rs b/crates/vm/src/builtins/set.rs index 5582ff3323c..b1236e44e93 100644 --- a/crates/vm/src/builtins/set.rs +++ b/crates/vm/src/builtins/set.rs @@ -23,12 +23,13 @@ use crate::{ utils::collection_repr, vm::VirtualMachine, }; +use alloc::fmt; +use core::ops::Deref; use rustpython_common::{ atomic::{Ordering, PyAtomic, Radium}, hash, }; use std::sync::LazyLock; -use std::{fmt, ops::Deref}; pub type SetContentType = dict_inner::Dict<()>; @@ -50,7 +51,7 @@ impl PySet { fn fold_op( &self, - others: impl std::iter::Iterator, + others: impl core::iter::Iterator, op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult, vm: &VirtualMachine, ) -> PyResult { @@ -68,7 +69,7 @@ impl PySet { Ok(Self { inner: self .inner - .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?, + .fold_op(core::iter::once(other.into_iterable(vm)?), op, vm)?, }) } } @@ -111,7 +112,7 @@ impl PyFrozenSet { fn fold_op( &self, - others: impl std::iter::Iterator, + others: impl core::iter::Iterator, op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult, vm: &VirtualMachine, ) -> PyResult { @@ -130,7 +131,7 @@ impl PyFrozenSet { Ok(Self { inner: self .inner - .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?, + .fold_op(core::iter::once(other.into_iterable(vm)?), op, vm)?, ..Default::default() }) } @@ -191,7 +192,7 @@ impl PySetInner { fn fold_op( &self, - others: impl std::iter::Iterator, + others: impl core::iter::Iterator, op: fn(&Self, O, &VirtualMachine) -> PyResult, vm: &VirtualMachine, ) -> PyResult { @@ -352,7 +353,7 @@ impl PySetInner { fn update( &self, - others: impl std::iter::Iterator, + others: impl core::iter::Iterator, vm: 
&VirtualMachine, ) -> PyResult<()> { for iterable in others { @@ -395,7 +396,7 @@ impl PySetInner { fn intersection_update( &self, - others: impl std::iter::Iterator, + others: impl core::iter::Iterator, vm: &VirtualMachine, ) -> PyResult<()> { let temp_inner = self.fold_op(others, Self::intersection, vm)?; @@ -408,7 +409,7 @@ impl PySetInner { fn difference_update( &self, - others: impl std::iter::Iterator, + others: impl core::iter::Iterator, vm: &VirtualMachine, ) -> PyResult<()> { for iterable in others { @@ -422,7 +423,7 @@ impl PySetInner { fn symmetric_difference_update( &self, - others: impl std::iter::Iterator, + others: impl core::iter::Iterator, vm: &VirtualMachine, ) -> PyResult<()> { for iterable in others { @@ -539,7 +540,7 @@ impl PySet { #[pymethod] fn __sizeof__(&self) -> usize { - std::mem::size_of::() + self.inner.sizeof() + core::mem::size_of::() + self.inner.sizeof() } #[pymethod] @@ -731,7 +732,7 @@ impl PySet { #[pymethod] fn __iand__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { zelf.inner - .intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?; + .intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?; Ok(zelf) } @@ -978,7 +979,7 @@ impl PyFrozenSet { #[pymethod] fn __sizeof__(&self) -> usize { - std::mem::size_of::() + self.inner.sizeof() + core::mem::size_of::() + self.inner.sizeof() } #[pymethod] @@ -1258,8 +1259,8 @@ impl AnySet { fn into_iterable_iter( self, vm: &VirtualMachine, - ) -> PyResult> { - Ok(std::iter::once(self.into_iterable(vm)?)) + ) -> PyResult> { + Ok(core::iter::once(self.into_iterable(vm)?)) } fn as_inner(&self) -> &PySetInner { diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index fa35c1725d4..0357e81b365 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -24,8 +24,10 @@ use crate::{ PyComparisonOp, Representable, SelfIter, }, }; +use alloc::{borrow::Cow, fmt}; use ascii::{AsciiChar, AsciiStr, AsciiString}; use 
bstr::ByteSlice; +use core::{char, mem, ops::Range}; use itertools::Itertools; use num_traits::ToPrimitive; use rustpython_common::{ @@ -37,8 +39,7 @@ use rustpython_common::{ str::DeduceStrKind, wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk}, }; -use std::{borrow::Cow, char, fmt, ops::Range}; -use std::{mem, sync::LazyLock}; +use std::sync::LazyLock; use unic_ucd_bidi::BidiClass; use unic_ucd_category::GeneralCategory; use unic_ucd_ident::{is_xid_continue, is_xid_start}; @@ -191,8 +192,8 @@ impl From for PyStr { } } -impl<'a> From> for PyStr { - fn from(s: std::borrow::Cow<'a, str>) -> Self { +impl<'a> From> for PyStr { + fn from(s: alloc::borrow::Cow<'a, str>) -> Self { s.into_owned().into() } } @@ -632,7 +633,7 @@ impl PyStr { #[pymethod] fn __sizeof__(&self) -> usize { - std::mem::size_of::() + self.byte_len() * std::mem::size_of::() + core::mem::size_of::() + self.byte_len() * core::mem::size_of::() } #[pymethod(name = "__rmul__")] @@ -1045,7 +1046,7 @@ impl PyStr { #[pymethod] fn replace(&self, args: ReplaceArgs) -> Wtf8Buf { - use std::cmp::Ordering; + use core::cmp::Ordering; let s = self.as_wtf8(); let ReplaceArgs { old, new, count } = args; @@ -1361,7 +1362,7 @@ impl PyStr { let ch = bigint .as_bigint() .to_u32() - .and_then(std::char::from_u32) + .and_then(core::char::from_u32) .ok_or_else(|| { vm.new_value_error("character mapping must be in range(0x110000)") })?; @@ -1494,7 +1495,7 @@ impl PyRef { } struct CharLenStr<'a>(&'a str, usize); -impl std::ops::Deref for CharLenStr<'_> { +impl core::ops::Deref for CharLenStr<'_> { type Target = str; fn deref(&self) -> &Self::Target { @@ -1705,7 +1706,7 @@ pub struct FindArgs { } impl FindArgs { - fn get_value(self, len: usize) -> (PyStrRef, std::ops::Range) { + fn get_value(self, len: usize) -> (PyStrRef, core::ops::Range) { let range = adjust_indices(self.start, self.end, len); (self.sub, range) } @@ -1945,8 +1946,8 @@ impl PyPayload for PyUtf8Str { ctx.types.str_type } - fn payload_type_id() -> std::any::TypeId 
{ - std::any::TypeId::of::() + fn payload_type_id() -> core::any::TypeId { + core::any::TypeId::of::() } fn validate_downcastable_from(obj: &PyObject) -> bool { @@ -2005,8 +2006,8 @@ impl From for PyUtf8Str { } } -impl<'a> From> for PyUtf8Str { - fn from(s: std::borrow::Cow<'a, str>) -> Self { +impl<'a> From> for PyUtf8Str { + fn from(s: alloc::borrow::Cow<'a, str>) -> Self { s.into_owned().into() } } @@ -2128,11 +2129,11 @@ impl AnyStr for str { Self::chars(self) } - fn get_bytes(&self, range: std::ops::Range) -> &Self { + fn get_bytes(&self, range: core::ops::Range) -> &Self { &self[range] } - fn get_chars(&self, range: std::ops::Range) -> &Self { + fn get_chars(&self, range: core::ops::Range) -> &Self { rustpython_common::str::get_chars(self, range) } @@ -2239,11 +2240,11 @@ impl AnyStr for Wtf8 { self.code_points() } - fn get_bytes(&self, range: std::ops::Range) -> &Self { + fn get_bytes(&self, range: core::ops::Range) -> &Self { &self[range] } - fn get_chars(&self, range: std::ops::Range) -> &Self { + fn get_chars(&self, range: core::ops::Range) -> &Self { rustpython_common::str::get_codepoints(self, range) } @@ -2361,11 +2362,11 @@ impl AnyStr for AsciiStr { self.chars() } - fn get_bytes(&self, range: std::ops::Range) -> &Self { + fn get_bytes(&self, range: core::ops::Range) -> &Self { &self[range] } - fn get_chars(&self, range: std::ops::Range) -> &Self { + fn get_chars(&self, range: core::ops::Range) -> &Self { &self[range] } @@ -2436,8 +2437,8 @@ impl PyStrInterned { } } -impl std::fmt::Display for PyStrInterned { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for PyStrInterned { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { self.data.fmt(f) } } diff --git a/crates/vm/src/builtins/traceback.rs b/crates/vm/src/builtins/traceback.rs index 6bf4070c9b5..84493e7125f 100644 --- a/crates/vm/src/builtins/traceback.rs +++ b/crates/vm/src/builtins/traceback.rs @@ -81,7 +81,7 @@ impl 
Constructor for PyTraceback { impl PyTracebackRef { pub fn iter(&self) -> impl Iterator { - std::iter::successors(Some(self.clone()), |tb| tb.next.lock().clone()) + core::iter::successors(Some(self.clone()), |tb| tb.next.lock().clone()) } } diff --git a/crates/vm/src/builtins/tuple.rs b/crates/vm/src/builtins/tuple.rs index 13335428b35..70e4e20e405 100644 --- a/crates/vm/src/builtins/tuple.rs +++ b/crates/vm/src/builtins/tuple.rs @@ -21,7 +21,8 @@ use crate::{ utils::collection_repr, vm::VirtualMachine, }; -use std::{fmt, sync::LazyLock}; +use alloc::fmt; +use std::sync::LazyLock; #[pyclass(module = false, name = "tuple", traverse)] pub struct PyTuple { @@ -158,7 +159,7 @@ impl AsRef<[R]> for PyTuple { } } -impl std::ops::Deref for PyTuple { +impl core::ops::Deref for PyTuple { type Target = [R]; fn deref(&self) -> &[R] { @@ -166,18 +167,18 @@ impl std::ops::Deref for PyTuple { } } -impl<'a, R> std::iter::IntoIterator for &'a PyTuple { +impl<'a, R> core::iter::IntoIterator for &'a PyTuple { type Item = &'a R; - type IntoIter = std::slice::Iter<'a, R>; + type IntoIter = core::slice::Iter<'a, R>; fn into_iter(self) -> Self::IntoIter { self.iter() } } -impl<'a, R> std::iter::IntoIterator for &'a Py> { +impl<'a, R> core::iter::IntoIterator for &'a Py> { type Item = &'a R; - type IntoIter = std::slice::Iter<'a, R>; + type IntoIter = core::slice::Iter<'a, R>; fn into_iter(self) -> Self::IntoIter { self.iter() @@ -200,7 +201,7 @@ impl PyTuple { } #[inline] - pub fn iter(&self) -> std::slice::Iter<'_, R> { + pub fn iter(&self) -> core::slice::Iter<'_, R> { self.elements.iter() } } @@ -249,15 +250,15 @@ impl PyTuple> { // SAFETY: PyRef has the same layout as PyObjectRef unsafe { let elements: Vec = - std::mem::transmute::>, Vec>(elements); + core::mem::transmute::>, Vec>(elements); let tuple = PyTuple::::new_ref(elements, ctx); - std::mem::transmute::, PyRef>(tuple) + core::mem::transmute::, PyRef>(tuple) } } } #[pyclass( - itemsize = std::mem::size_of::(), + itemsize = 
core::mem::size_of::(), flags(BASETYPE, SEQUENCE, _MATCH_SELF), with(AsMapping, AsSequence, Hashable, Comparable, Iterable, Constructor, Representable) )] @@ -489,21 +490,21 @@ impl PyRef> { as TransmuteFromObject>::check(vm, elem)?; } // SAFETY: We just verified all elements are of type T - Ok(unsafe { std::mem::transmute::>>>(self) }) + Ok(unsafe { core::mem::transmute::>>>(self) }) } } impl PyRef>> { pub fn into_untyped(self) -> PyRef { // SAFETY: PyTuple> has the same layout as PyTuple - unsafe { std::mem::transmute::>(self) } + unsafe { core::mem::transmute::>(self) } } } impl Py>> { pub fn as_untyped(&self) -> &Py { // SAFETY: PyTuple> has the same layout as PyTuple - unsafe { std::mem::transmute::<&Self, &Py>(self) } + unsafe { core::mem::transmute::<&Self, &Py>(self) } } } diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index e807d3f4f8e..3eca5edc478 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -29,10 +29,11 @@ use crate::{ Representable, SLOT_DEFS, SetAttr, TypeDataRef, TypeDataRefMut, TypeDataSlot, }, }; +use core::{any::Any, borrow::Borrow, ops::Deref, pin::Pin, ptr::NonNull}; use indexmap::{IndexMap, map::Entry}; use itertools::Itertools; use num_traits::ToPrimitive; -use std::{any::Any, borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNull}; +use std::collections::HashSet; #[pyclass(module = false, name = "type", traverse = "manual")] pub struct PyType { @@ -118,14 +119,14 @@ unsafe impl Traverse for PyAttributes { } } -impl std::fmt::Display for PyType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(&self.name(), f) +impl core::fmt::Display for PyType { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Display::fmt(&self.name(), f) } } -impl std::fmt::Debug for PyType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyType { + fn 
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "[PyType {}]", &self.name()) } } @@ -368,7 +369,7 @@ impl PyType { // Inherit SEQUENCE and MAPPING flags from base class // For static types, we only have a single base - Self::inherit_patma_flags(&mut slots, std::slice::from_ref(&base)); + Self::inherit_patma_flags(&mut slots, core::slice::from_ref(&base)); if slots.basicsize == 0 { slots.basicsize = base.slots.basicsize; @@ -549,7 +550,7 @@ impl PyType { // Gather all members here: let mut attributes = PyAttributes::default(); - for bc in std::iter::once(self) + for bc in core::iter::once(self) .chain(self.mro.read().iter().map(|cls| -> &Self { cls })) .rev() { @@ -667,21 +668,21 @@ impl Py { where F: Fn(&Self) -> R, { - std::iter::once(self) + core::iter::once(self) .chain(self.mro.read().iter().map(|x| x.deref())) .map(f) .collect() } pub fn mro_collect(&self) -> Vec> { - std::iter::once(self) + core::iter::once(self) .chain(self.mro.read().iter().map(|x| x.deref())) .map(|x| x.to_owned()) .collect() } pub fn iter_base_chain(&self) -> impl Iterator { - std::iter::successors(Some(self), |cls| cls.base.as_deref()) + core::iter::successors(Some(self), |cls| cls.base.as_deref()) } pub fn extend_methods(&'static self, method_defs: &'static [PyMethodDef], ctx: &Context) { @@ -846,7 +847,7 @@ impl PyType { // then drop the old value after releasing the lock let _old_qualname = { let mut qualname_guard = heap_type.qualname.write(); - std::mem::replace(&mut *qualname_guard, str_value) + core::mem::replace(&mut *qualname_guard, str_value) }; // old_qualname is dropped here, outside the lock scope @@ -1012,7 +1013,7 @@ impl PyType { // then drop the old value after releasing the lock (similar to CPython's Py_SETREF) let _old_name = { let mut name_guard = self.heaptype_ext.as_ref().unwrap().name.write(); - std::mem::replace(&mut *name_guard, name) + core::mem::replace(&mut *name_guard, name) }; // old_name is dropped here, outside the lock scope 
diff --git a/crates/vm/src/builtins/union.rs b/crates/vm/src/builtins/union.rs index 83e1316d027..0342442b83d 100644 --- a/crates/vm/src/builtins/union.rs +++ b/crates/vm/src/builtins/union.rs @@ -11,7 +11,7 @@ use crate::{ stdlib::typing::TypeAliasType, types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable}, }; -use std::fmt; +use alloc::fmt; use std::sync::LazyLock; const CLS_ATTRS: &[&str] = &["__module__"]; diff --git a/crates/vm/src/bytes_inner.rs b/crates/vm/src/bytes_inner.rs index 8593f16fcd9..bb5db442c35 100644 --- a/crates/vm/src/bytes_inner.rs +++ b/crates/vm/src/bytes_inner.rs @@ -151,7 +151,7 @@ impl ByteInnerFindOptions { self, len: usize, vm: &VirtualMachine, - ) -> PyResult<(Vec, std::ops::Range)> { + ) -> PyResult<(Vec, core::ops::Range)> { let sub = match self.sub { Either::A(v) => v.elements.to_vec(), Either::B(int) => vec![int.as_bigint().byte_or(vm)?], @@ -719,7 +719,7 @@ impl PyBytesInner { // len(self)>=1, from="", len(to)>=1, max_count>=1 fn replace_interleave(&self, to: Self, max_count: Option) -> Vec { let place_count = self.elements.len() + 1; - let count = max_count.map_or(place_count, |v| std::cmp::min(v, place_count)) - 1; + let count = max_count.map_or(place_count, |v| core::cmp::min(v, place_count)) - 1; let capacity = self.elements.len() + count * to.len(); let mut result = Vec::with_capacity(capacity); let to_slice = to.elements.as_slice(); @@ -952,7 +952,7 @@ where fn count_substring(haystack: &[u8], needle: &[u8], max_count: Option) -> usize { let substrings = haystack.find_iter(needle); if let Some(max_count) = max_count { - std::cmp::min(substrings.take(max_count).count(), max_count) + core::cmp::min(substrings.take(max_count).count(), max_count) } else { substrings.count() } @@ -1025,11 +1025,11 @@ impl AnyStr for [u8] { self.iter().copied() } - fn get_bytes(&self, range: std::ops::Range) -> &Self { + fn get_bytes(&self, range: core::ops::Range) -> &Self { &self[range] } - fn 
get_chars(&self, range: std::ops::Range) -> &Self { + fn get_chars(&self, range: core::ops::Range) -> &Self { &self[range] } @@ -1120,7 +1120,7 @@ fn hex_impl(bytes: &[u8], sep: u8, bytes_per_sep: isize) -> String { let len = bytes.len(); let buf = if bytes_per_sep < 0 { - let bytes_per_sep = std::cmp::min(len, (-bytes_per_sep) as usize); + let bytes_per_sep = core::cmp::min(len, (-bytes_per_sep) as usize); let chunks = (len - 1) / bytes_per_sep; let chunked = chunks * bytes_per_sep; let unchunked = len - chunked; @@ -1139,7 +1139,7 @@ fn hex_impl(bytes: &[u8], sep: u8, bytes_per_sep: isize) -> String { hex::encode_to_slice(&bytes[chunked..], &mut buf[j..j + unchunked * 2]).unwrap(); buf } else { - let bytes_per_sep = std::cmp::min(len, bytes_per_sep as usize); + let bytes_per_sep = core::cmp::min(len, bytes_per_sep as usize); let chunks = (len - 1) / bytes_per_sep; let chunked = chunks * bytes_per_sep; let unchunked = len - chunked; diff --git a/crates/vm/src/cformat.rs b/crates/vm/src/cformat.rs index efb3cb2acc9..939b1c7760f 100644 --- a/crates/vm/src/cformat.rs +++ b/crates/vm/src/cformat.rs @@ -342,7 +342,7 @@ pub(crate) fn cformat_bytes( let values = if let Some(tup) = values_obj.downcast_ref::() { tup.as_slice() } else { - std::slice::from_ref(&values_obj) + core::slice::from_ref(&values_obj) }; let mut value_iter = values.iter(); @@ -435,7 +435,7 @@ pub(crate) fn cformat_string( let values = if let Some(tup) = values_obj.downcast_ref::() { tup.as_slice() } else { - std::slice::from_ref(&values_obj) + core::slice::from_ref(&values_obj) }; let mut value_iter = values.iter(); diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index ce41abcc60b..98dc6fd2ed2 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -169,7 +169,7 @@ pub trait PyClassImpl: PyClassDef { // Exception: object itself should have __new__ in its dict if let Some(slot_new) = class.slots.new.load() { let object_new = ctx.types.object_type.slots.new.load(); - let 
is_object_itself = std::ptr::eq(class, ctx.types.object_type); + let is_object_itself = core::ptr::eq(class, ctx.types.object_type); let is_inherited_from_object = !is_object_itself && object_new.is_some_and(|obj_new| slot_new as usize == obj_new as usize); @@ -203,7 +203,7 @@ pub trait PyClassImpl: PyClassDef { Self::extend_class(ctx, unsafe { // typ will be saved in static_cell let r: &Py = &typ; - let r: &'static Py = std::mem::transmute(r); + let r: &'static Py = core::mem::transmute(r); r }); typ diff --git a/crates/vm/src/codecs.rs b/crates/vm/src/codecs.rs index 2edb67b497b..3241dee4981 100644 --- a/crates/vm/src/codecs.rs +++ b/crates/vm/src/codecs.rs @@ -16,12 +16,10 @@ use crate::{ convert::ToPyObject, function::{ArgBytesLike, PyMethodDef}, }; +use alloc::borrow::Cow; +use core::ops::{self, Range}; use once_cell::unsync::OnceCell; -use std::{ - borrow::Cow, - collections::HashMap, - ops::{self, Range}, -}; +use std::collections::HashMap; pub struct CodecsRegistry { inner: PyRwLock, diff --git a/crates/vm/src/convert/try_from.rs b/crates/vm/src/convert/try_from.rs index 4f921e9c5de..ceb7d003e9b 100644 --- a/crates/vm/src/convert/try_from.rs +++ b/crates/vm/src/convert/try_from.rs @@ -122,7 +122,7 @@ impl<'a, T: PyPayload> TryFromBorrowedObject<'a> for &'a Py { } } -impl TryFromObject for std::time::Duration { +impl TryFromObject for core::time::Duration { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { if let Some(float) = obj.downcast_ref::() { let f = float.to_f64(); diff --git a/crates/vm/src/dict_inner.rs b/crates/vm/src/dict_inner.rs index 02e237afb00..d57f8be0fe7 100644 --- a/crates/vm/src/dict_inner.rs +++ b/crates/vm/src/dict_inner.rs @@ -16,8 +16,9 @@ use crate::{ }, object::{Traverse, TraverseFn}, }; +use alloc::fmt; +use core::{mem::size_of, ops::ControlFlow}; use num_traits::ToPrimitive; -use std::{fmt, mem::size_of, ops::ControlFlow}; // HashIndex is intended to be same size with hash::PyHash // but it doesn't mean the 
values are compatible with actual PyHash value @@ -281,7 +282,7 @@ impl Dict { continue; }; if entry.index == index_index { - let removed = std::mem::replace(&mut entry.value, value); + let removed = core::mem::replace(&mut entry.value, value); // defer dec RC break Some(removed); } else { @@ -357,7 +358,7 @@ impl Dict { inner.used = 0; inner.filled = 0; // defer dec rc - std::mem::take(&mut inner.entries) + core::mem::take(&mut inner.entries) }; } @@ -633,7 +634,7 @@ impl Dict { // returns Err(()) if changed since lookup fn pop_inner(&self, lookup: LookupResult) -> PopInnerResult { - self.pop_inner_if(lookup, |_| Ok::<_, std::convert::Infallible>(true)) + self.pop_inner_if(lookup, |_| Ok::<_, core::convert::Infallible>(true)) .unwrap_or_else(|x| match x {}) } diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 4e8b572e457..b1a654d58c4 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -33,8 +33,8 @@ unsafe impl Traverse for PyBaseException { } } -impl std::fmt::Debug for PyBaseException { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyBaseException { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { // TODO: implement more detailed, non-recursive Debug formatter f.write_str("PyBaseException") } @@ -1173,7 +1173,7 @@ pub fn cstring_error(vm: &VirtualMachine) -> PyBaseExceptionRef { vm.new_value_error("embedded null character") } -impl ToPyException for std::ffi::NulError { +impl ToPyException for alloc::ffi::NulError { fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { cstring_error(vm) } @@ -1604,8 +1604,8 @@ pub(super) mod types { } } - impl std::fmt::Debug for PyOSError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for PyOSError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyOSError").finish_non_exhaustive() } } diff --git 
a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 3f470e5453f..6acf0e84795 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -18,13 +18,14 @@ use crate::{ types::PyTypeFlags, vm::{Context, PyMethod}, }; +use alloc::fmt; +use core::iter::zip; +#[cfg(feature = "threading")] +use core::sync::atomic; use indexmap::IndexMap; use itertools::Itertools; use rustpython_common::{boxvec::BoxVec, lock::PyMutex, wtf8::Wtf8Buf}; use rustpython_compiler_core::SourceLocation; -#[cfg(feature = "threading")] -use std::sync::atomic; -use std::{fmt, iter::zip}; #[derive(Clone, Debug)] struct Block { @@ -91,7 +92,7 @@ struct FrameState { #[cfg(feature = "threading")] type Lasti = atomic::AtomicU32; #[cfg(not(feature = "threading"))] -type Lasti = std::cell::Cell; +type Lasti = core::cell::Cell; #[pyclass(module = false, name = "frame")] pub struct Frame { @@ -142,7 +143,7 @@ impl Frame { func_obj: Option, vm: &VirtualMachine, ) -> Self { - let cells_frees = std::iter::repeat_with(|| PyCell::default().into_ref(&vm.ctx)) + let cells_frees = core::iter::repeat_with(|| PyCell::default().into_ref(&vm.ctx)) .take(code.cellvars.len()) .chain(closure.iter().cloned()) .collect(); @@ -189,7 +190,7 @@ impl Frame { let locals = &self.locals; let code = &**self.code; let map = &code.varnames; - let j = std::cmp::min(map.len(), code.varnames.len()); + let j = core::cmp::min(map.len(), code.varnames.len()); if !code.varnames.is_empty() { let fastlocals = self.fastlocals.lock(); for (&k, v) in zip(&map[..j], &**fastlocals) { @@ -2243,14 +2244,14 @@ impl ExecutingFrame<'_> { } })?; let msg = match elements.len().cmp(&(size as usize)) { - std::cmp::Ordering::Equal => { + core::cmp::Ordering::Equal => { self.state.stack.extend(elements.into_iter().rev()); return Ok(None); } - std::cmp::Ordering::Greater => { + core::cmp::Ordering::Greater => { format!("too many values to unpack (expected {size})") } - std::cmp::Ordering::Less => format!( + core::cmp::Ordering::Less => format!( 
"not enough values to unpack (expected {}, got {})", size, elements.len() @@ -2525,7 +2526,7 @@ impl ExecutingFrame<'_> { #[inline] fn replace_top(&mut self, mut top: PyObjectRef) -> PyObjectRef { let last = self.state.stack.last_mut().unwrap(); - std::mem::swap(&mut top, last); + core::mem::swap(&mut top, last); top } @@ -2561,12 +2562,12 @@ impl fmt::Debug for Frame { if elem.downcastable::() { s.push_str("\n > {frame}"); } else { - std::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); + core::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); } s }); let block_str = state.blocks.iter().fold(String::new(), |mut s, elem| { - std::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); + core::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); s }); // TODO: fix this up diff --git a/crates/vm/src/function/argument.rs b/crates/vm/src/function/argument.rs index d657ff6be8f..a4877cf4042 100644 --- a/crates/vm/src/function/argument.rs +++ b/crates/vm/src/function/argument.rs @@ -4,9 +4,9 @@ use crate::{ convert::ToPyObject, object::{Traverse, TraverseFn}, }; +use core::ops::RangeInclusive; use indexmap::IndexMap; use itertools::Itertools; -use std::ops::RangeInclusive; pub trait IntoFuncArgs: Sized { fn into_args(self, vm: &VirtualMachine) -> FuncArgs; @@ -100,7 +100,7 @@ impl From for FuncArgs { impl FromArgs for FuncArgs { fn from_args(_vm: &VirtualMachine, args: &mut FuncArgs) -> Result { - Ok(std::mem::take(args)) + Ok(core::mem::take(args)) } } @@ -424,7 +424,7 @@ impl PosArgs { self.0 } - pub fn iter(&self) -> std::slice::Iter<'_, T> { + pub fn iter(&self) -> core::slice::Iter<'_, T> { self.0.iter() } } @@ -469,7 +469,7 @@ where impl IntoIterator for PosArgs { type Item = T; - type IntoIter = std::vec::IntoIter; + type IntoIter = alloc::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() diff --git a/crates/vm/src/function/builtin.rs b/crates/vm/src/function/builtin.rs index 1a91e4344ba..06fd6a44f54 100644 --- 
a/crates/vm/src/function/builtin.rs +++ b/crates/vm/src/function/builtin.rs @@ -3,7 +3,7 @@ use crate::{ Py, PyPayload, PyRef, PyResult, VirtualMachine, convert::ToPyResult, object::PyThreadingConstraint, }; -use std::marker::PhantomData; +use core::marker::PhantomData; /// A built-in Python function. // PyCFunction in CPython @@ -54,14 +54,14 @@ const fn zst_ref_out_of_thin_air(x: T) -> &'static T { // if T is zero-sized, there's no issue forgetting it - even if it does have a Drop impl, it // would never get called anyway if we consider this semantically a Box::leak(Box::new(x))-type // operation. if T isn't zero-sized, we don't have to worry about it because we'll fail to compile. - std::mem::forget(x); + core::mem::forget(x); const { - if std::mem::size_of::() != 0 { + if core::mem::size_of::() != 0 { panic!("can't use a non-zero-sized type here") } // SAFETY: we just confirmed that T is zero-sized, so we can // pull a value of it out of thin air. - unsafe { std::ptr::NonNull::::dangling().as_ref() } + unsafe { core::ptr::NonNull::::dangling().as_ref() } } } diff --git a/crates/vm/src/function/either.rs b/crates/vm/src/function/either.rs index 8700c6150db..9ee7f028bd2 100644 --- a/crates/vm/src/function/either.rs +++ b/crates/vm/src/function/either.rs @@ -1,7 +1,7 @@ use crate::{ AsObject, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, convert::ToPyObject, }; -use std::borrow::Borrow; +use core::borrow::Borrow; pub enum Either { A(A), diff --git a/crates/vm/src/function/fspath.rs b/crates/vm/src/function/fspath.rs index 44d41ab7632..7d3a0dcbbd5 100644 --- a/crates/vm/src/function/fspath.rs +++ b/crates/vm/src/function/fspath.rs @@ -5,7 +5,8 @@ use crate::{ function::PyStr, protocol::PyBuffer, }; -use std::{borrow::Cow, ffi::OsStr, path::PathBuf}; +use alloc::borrow::Cow; +use std::{ffi::OsStr, path::PathBuf}; /// Helper to implement os.fspath() #[derive(Clone)] @@ -111,8 +112,8 @@ impl FsPath { Ok(path) } - pub fn to_cstring(&self, vm: 
&VirtualMachine) -> PyResult { - std::ffi::CString::new(self.as_bytes()).map_err(|e| e.into_pyexception(vm)) + pub fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { + alloc::ffi::CString::new(self.as_bytes()).map_err(|e| e.into_pyexception(vm)) } #[cfg(windows)] diff --git a/crates/vm/src/function/method.rs b/crates/vm/src/function/method.rs index 5e109176c5e..6440fd801fc 100644 --- a/crates/vm/src/function/method.rs +++ b/crates/vm/src/function/method.rs @@ -251,14 +251,14 @@ impl PyMethodDef { } } -impl std::fmt::Debug for PyMethodDef { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyMethodDef { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyMethodDef") .field("name", &self.name) .field( "func", &(unsafe { - std::mem::transmute::<&dyn PyNativeFn, [usize; 2]>(self.func)[1] as *const u8 + core::mem::transmute::<&dyn PyNativeFn, [usize; 2]>(self.func)[1] as *const u8 }), ) .field("flags", &self.flags) diff --git a/crates/vm/src/function/number.rs b/crates/vm/src/function/number.rs index 7bb37b8f549..fb872cc48fd 100644 --- a/crates/vm/src/function/number.rs +++ b/crates/vm/src/function/number.rs @@ -1,9 +1,9 @@ use super::argument::OptionalArg; use crate::{AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::PyIntRef}; +use core::ops::Deref; use malachite_bigint::BigInt; use num_complex::Complex64; use num_traits::PrimInt; -use std::ops::Deref; /// A Python complex-like object. 
/// @@ -62,7 +62,7 @@ pub struct ArgIntoFloat { impl ArgIntoFloat { pub fn vec_into_f64(v: Vec) -> Vec { // TODO: Vec::into_raw_parts once stabilized - let mut v = std::mem::ManuallyDrop::new(v); + let mut v = core::mem::ManuallyDrop::new(v); let (p, l, c) = (v.as_mut_ptr(), v.len(), v.capacity()); // SAFETY: IntoPyFloat is repr(transparent) over f64 unsafe { Vec::from_raw_parts(p.cast(), l, c) } diff --git a/crates/vm/src/function/protocol.rs b/crates/vm/src/function/protocol.rs index a87ef339edd..94bdd3027eb 100644 --- a/crates/vm/src/function/protocol.rs +++ b/crates/vm/src/function/protocol.rs @@ -7,7 +7,7 @@ use crate::{ protocol::{PyIter, PyIterIter, PyMapping}, types::GenericMethod, }; -use std::{borrow::Borrow, marker::PhantomData, ops::Deref}; +use core::{borrow::Borrow, marker::PhantomData, ops::Deref}; #[derive(Clone, Traverse)] pub struct ArgCallable { @@ -24,8 +24,8 @@ impl ArgCallable { } } -impl std::fmt::Debug for ArgCallable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for ArgCallable { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("ArgCallable") .field("obj", &self.obj) .field("call", &format!("{:08x}", self.call as usize)) @@ -203,7 +203,7 @@ impl ArgSequence { } } -impl std::ops::Deref for ArgSequence { +impl core::ops::Deref for ArgSequence { type Target = [T]; #[inline(always)] fn deref(&self) -> &[T] { @@ -213,14 +213,14 @@ impl std::ops::Deref for ArgSequence { impl<'a, T> IntoIterator for &'a ArgSequence { type Item = &'a T; - type IntoIter = std::slice::Iter<'a, T>; + type IntoIter = core::slice::Iter<'a, T>; fn into_iter(self) -> Self::IntoIter { self.iter() } } impl IntoIterator for ArgSequence { type Item = T; - type IntoIter = std::vec::IntoIter; + type IntoIter = alloc::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } diff --git a/crates/vm/src/intern.rs b/crates/vm/src/intern.rs index a5b2a798d53..a50b8871cb9 100644 
--- a/crates/vm/src/intern.rs +++ b/crates/vm/src/intern.rs @@ -6,10 +6,8 @@ use crate::{ common::lock::PyRwLock, convert::ToPyObject, }; -use std::{ - borrow::{Borrow, ToOwned}, - ops::Deref, -}; +use alloc::borrow::ToOwned; +use core::{borrow::Borrow, ops::Deref}; #[derive(Debug)] pub struct StringPool { @@ -86,8 +84,8 @@ pub struct CachedPyStrRef { inner: PyRefExact, } -impl std::hash::Hash for CachedPyStrRef { - fn hash(&self, state: &mut H) { +impl core::hash::Hash for CachedPyStrRef { + fn hash(&self, state: &mut H) { self.inner.as_wtf8().hash(state) } } @@ -100,7 +98,7 @@ impl PartialEq for CachedPyStrRef { impl Eq for CachedPyStrRef {} -impl std::borrow::Borrow for CachedPyStrRef { +impl core::borrow::Borrow for CachedPyStrRef { #[inline] fn borrow(&self) -> &Wtf8 { self.as_wtf8() @@ -119,7 +117,7 @@ impl CachedPyStrRef { /// the given cache must be alive while returned reference is alive #[inline] const unsafe fn as_interned_str(&self) -> &'static PyStrInterned { - unsafe { std::mem::transmute_copy(self) } + unsafe { core::mem::transmute_copy(self) } } #[inline] @@ -135,7 +133,7 @@ pub struct PyInterned { impl PyInterned { #[inline] pub fn leak(cache: PyRef) -> &'static Self { - unsafe { std::mem::transmute(cache) } + unsafe { core::mem::transmute(cache) } } #[inline] @@ -163,9 +161,9 @@ impl Borrow for PyInterned { // NOTE: std::hash::Hash of Self and Self::Borrowed *must* be the same // This is ok only because PyObject doesn't implement Hash -impl std::hash::Hash for PyInterned { +impl core::hash::Hash for PyInterned { #[inline(always)] - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.get_id().hash(state) } } @@ -188,15 +186,15 @@ impl Deref for PyInterned { impl PartialEq for PyInterned { #[inline(always)] fn eq(&self, other: &Self) -> bool { - std::ptr::eq(self, other) + core::ptr::eq(self, other) } } impl Eq for PyInterned {} -impl std::fmt::Debug for PyInterned { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { - std::fmt::Debug::fmt(&**self, f)?; +impl core::fmt::Debug for PyInterned { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Debug::fmt(&**self, f)?; write!(f, "@{:p}", self.as_ptr()) } } @@ -308,7 +306,7 @@ impl MaybeInternedString for Py { #[inline(always)] fn as_interned(&self) -> Option<&'static PyStrInterned> { if self.as_object().is_interned() { - Some(unsafe { std::mem::transmute::<&Self, &PyInterned>(self) }) + Some(unsafe { core::mem::transmute::<&Self, &PyInterned>(self) }) } else { None } diff --git a/crates/vm/src/lib.rs b/crates/vm/src/lib.rs index f461c612955..3f0eee278a2 100644 --- a/crates/vm/src/lib.rs +++ b/crates/vm/src/lib.rs @@ -24,6 +24,7 @@ extern crate bitflags; #[macro_use] extern crate log; // extern crate env_logger; +extern crate alloc; #[macro_use] extern crate rustpython_derive; diff --git a/crates/vm/src/macros.rs b/crates/vm/src/macros.rs index 1284c202782..f5c912d89cf 100644 --- a/crates/vm/src/macros.rs +++ b/crates/vm/src/macros.rs @@ -146,8 +146,8 @@ macro_rules! 
match_class { }; (match ($obj:expr) { ref $binding:ident @ $class:ty => $expr:expr, $($rest:tt)* }) => { match $obj.downcast_ref::<$class>() { - ::std::option::Option::Some($binding) => $expr, - ::std::option::Option::None => $crate::match_class!(match ($obj) { $($rest)* }), + core::option::Option::Some($binding) => $expr, + core::option::Option::None => $crate::match_class!(match ($obj) { $($rest)* }), } }; diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index d52a33884ce..cee3fac266e 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -31,11 +31,13 @@ use crate::{ object::traverse::{MaybeTraverse, Traverse, TraverseFn}, }; use itertools::Itertools; -use std::{ + +use alloc::fmt; + +use core::{ any::TypeId, borrow::Borrow, cell::UnsafeCell, - fmt, marker::PhantomData, mem::ManuallyDrop, ops::Deref, @@ -82,7 +84,7 @@ pub(super) struct Erased; pub(super) unsafe fn drop_dealloc_obj(x: *mut PyObject) { drop(unsafe { Box::from_raw(x as *mut PyInner) }); } -pub(super) unsafe fn debug_obj( +pub(super) unsafe fn debug_obj( x: &PyObject, f: &mut fmt::Formatter<'_>, ) -> fmt::Result { @@ -114,7 +116,7 @@ pub(super) struct PyInner { pub(super) payload: T, } -pub(crate) const SIZEOF_PYOBJECT_HEAD: usize = std::mem::size_of::>(); +pub(crate) const SIZEOF_PYOBJECT_HEAD: usize = core::mem::size_of::>(); impl fmt::Debug for PyInner { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -376,7 +378,7 @@ impl PyWeak { let node_ptr = unsafe { NonNull::new_unchecked(py_inner as *mut Py) }; // the list doesn't have ownership over its PyRef! we're being dropped // right now so that should be obvious!! 
- std::mem::forget(unsafe { guard.list.remove(node_ptr) }); + core::mem::forget(unsafe { guard.list.remove(node_ptr) }); guard.ref_count -= 1; if Some(node_ptr) == guard.generic_weakref { guard.generic_weakref = None; @@ -438,11 +440,11 @@ impl InstanceDict { #[inline] pub fn replace(&self, d: PyDictRef) -> PyDictRef { - std::mem::replace(&mut self.d.write(), d) + core::mem::replace(&mut self.d.write(), d) } } -impl PyInner { +impl PyInner { fn new(payload: T, typ: PyTypeRef, dict: Option) -> Box { let member_count = typ.slots.member_count; Box::new(Self { @@ -453,7 +455,7 @@ impl PyInner { dict: dict.map(InstanceDict::new), weak_list: WeakRefList::new(), payload, - slots: std::iter::repeat_with(|| PyRwLock::new(None)) + slots: core::iter::repeat_with(|| PyRwLock::new(None)) .take(member_count) .collect_vec() .into_boxed_slice(), @@ -513,7 +515,7 @@ impl PyObjectRef { #[inline(always)] pub const fn into_raw(self) -> NonNull { let ptr = self.ptr; - std::mem::forget(self); + core::mem::forget(self); ptr } @@ -946,12 +948,12 @@ impl Borrow for Py { } } -impl std::hash::Hash for Py +impl core::hash::Hash for Py where - T: std::hash::Hash + PyPayload, + T: core::hash::Hash + PyPayload, { #[inline] - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.deref().hash(state) } } @@ -978,7 +980,7 @@ where } } -impl fmt::Debug for Py { +impl fmt::Debug for Py { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (**self).fmt(f) } @@ -1059,12 +1061,12 @@ impl PyRef { pub const fn leak(pyref: Self) -> &'static Py { let ptr = pyref.ptr; - std::mem::forget(pyref); + core::mem::forget(pyref); unsafe { ptr.as_ref() } } } -impl PyRef { +impl PyRef { #[inline(always)] pub fn new_ref(payload: T, typ: crate::builtins::PyTypeRef, dict: Option) -> Self { let inner = Box::into_raw(PyInner::new(payload, typ, dict)); @@ -1074,9 +1076,9 @@ impl PyRef { } } -impl PyRef +impl PyRef where - T::Base: std::fmt::Debug, + T::Base: core::fmt::Debug, { /// Converts this 
reference to the base type (ownership transfer). /// # Safety @@ -1086,7 +1088,7 @@ where let obj: PyObjectRef = self.into(); match obj.downcast() { Ok(base_ref) => base_ref, - Err(_) => unsafe { std::hint::unreachable_unchecked() }, + Err(_) => unsafe { core::hint::unreachable_unchecked() }, } } #[inline] @@ -1098,7 +1100,7 @@ where let obj: PyObjectRef = self.into(); match obj.downcast::() { Ok(upcast_ref) => upcast_ref, - Err(_) => unsafe { std::hint::unreachable_unchecked() }, + Err(_) => unsafe { core::hint::unreachable_unchecked() }, } } } @@ -1176,12 +1178,12 @@ impl Deref for PyRef { } } -impl std::hash::Hash for PyRef +impl core::hash::Hash for PyRef where - T: std::hash::Hash + PyPayload, + T: core::hash::Hash + PyPayload, { #[inline] - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.deref().hash(state) } } @@ -1230,10 +1232,10 @@ macro_rules! partially_init { $($uninit_field: unreachable!(),)* }}; } - let mut m = ::std::mem::MaybeUninit::<$ty>::uninit(); + let mut m = ::core::mem::MaybeUninit::<$ty>::uninit(); #[allow(unused_unsafe)] unsafe { - $(::std::ptr::write(&mut (*m.as_mut_ptr()).$init_field, $init_value);)* + $(::core::ptr::write(&mut (*m.as_mut_ptr()).$init_field, $init_value);)* } m }}; @@ -1241,7 +1243,7 @@ macro_rules! partially_init { pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { use crate::{builtins::object, class::PyClassImpl}; - use std::mem::MaybeUninit; + use core::mem::MaybeUninit; // `type` inherits from `object` // and both `type` and `object are instances of `type`. 
diff --git a/crates/vm/src/object/ext.rs b/crates/vm/src/object/ext.rs index 88f5fdc66d7..c1a5f63f85e 100644 --- a/crates/vm/src/object/ext.rs +++ b/crates/vm/src/object/ext.rs @@ -12,9 +12,10 @@ use crate::{ convert::{IntoPyException, ToPyObject, ToPyResult, TryFromObject}, vm::Context, }; -use std::{ +use alloc::fmt; + +use core::{ borrow::Borrow, - fmt, marker::PhantomData, ops::Deref, ptr::{NonNull, null_mut}, @@ -108,7 +109,7 @@ impl AsRef> for PyExact { } } -impl std::borrow::ToOwned for PyExact { +impl alloc::borrow::ToOwned for PyExact { type Owned = PyRefExact; fn to_owned(&self) -> Self::Owned { @@ -581,7 +582,7 @@ impl ToPyObject for &PyObject { // explicitly implementing `ToPyObject`. impl ToPyObject for T where - T: PyPayload + std::fmt::Debug + Sized, + T: PyPayload + core::fmt::Debug + Sized, { #[inline(always)] fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { diff --git a/crates/vm/src/object/payload.rs b/crates/vm/src/object/payload.rs index 4b900b7caa1..3a2f42675f7 100644 --- a/crates/vm/src/object/payload.rs +++ b/crates/vm/src/object/payload.rs @@ -27,8 +27,8 @@ pub(crate) fn cold_downcast_type_error( pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline] - fn payload_type_id() -> std::any::TypeId { - std::any::TypeId::of::() + fn payload_type_id() -> core::any::TypeId { + core::any::TypeId::of::() } /// # Safety: this function should only be called if `payload_type_id` matches the type of `obj`. 
@@ -56,7 +56,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline] fn into_pyobject(self, vm: &VirtualMachine) -> PyObjectRef where - Self: std::fmt::Debug, + Self: core::fmt::Debug, { self.into_ref(&vm.ctx).into() } @@ -64,7 +64,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline] fn _into_ref(self, cls: PyTypeRef, ctx: &Context) -> PyRef where - Self: std::fmt::Debug, + Self: core::fmt::Debug, { let dict = if cls.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { Some(ctx.new_dict()) @@ -77,7 +77,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline] fn into_exact_ref(self, ctx: &Context) -> PyRefExact where - Self: std::fmt::Debug, + Self: core::fmt::Debug, { unsafe { // Self::into_ref() always returns exact typed PyRef @@ -88,7 +88,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline] fn into_ref(self, ctx: &Context) -> PyRef where - Self: std::fmt::Debug, + Self: core::fmt::Debug, { let cls = Self::class(ctx); self._into_ref(cls.to_owned(), ctx) @@ -97,7 +97,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline] fn into_ref_with_type(self, vm: &VirtualMachine, cls: PyTypeRef) -> PyResult> where - Self: std::fmt::Debug, + Self: core::fmt::Debug, { let exact_class = Self::class(&vm.ctx); if cls.fast_issubclass(exact_class) { @@ -138,11 +138,11 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { } pub trait PyObjectPayload: - PyPayload + std::any::Any + std::fmt::Debug + MaybeTraverse + PyThreadingConstraint + 'static + PyPayload + core::any::Any + core::fmt::Debug + MaybeTraverse + PyThreadingConstraint + 'static { } -impl PyObjectPayload for T {} +impl PyObjectPayload for T {} pub trait SlotOffset { fn offset() -> usize; diff --git a/crates/vm/src/object/traverse.rs b/crates/vm/src/object/traverse.rs index 31bee8becea..2ce0db41a5e 100644 
--- a/crates/vm/src/object/traverse.rs +++ b/crates/vm/src/object/traverse.rs @@ -1,4 +1,4 @@ -use std::ptr::NonNull; +use core::ptr::NonNull; use rustpython_common::lock::{PyMutex, PyRwLock}; diff --git a/crates/vm/src/object/traverse_object.rs b/crates/vm/src/object/traverse_object.rs index 281b0e56eb5..075ce5b9513 100644 --- a/crates/vm/src/object/traverse_object.rs +++ b/crates/vm/src/object/traverse_object.rs @@ -1,4 +1,4 @@ -use std::fmt; +use alloc::fmt; use crate::{ PyObject, diff --git a/crates/vm/src/ospath.rs b/crates/vm/src/ospath.rs index 77abbee2cd5..25fcafb74c5 100644 --- a/crates/vm/src/ospath.rs +++ b/crates/vm/src/ospath.rs @@ -230,12 +230,12 @@ impl OsPath { self.path.into_encoded_bytes() } - pub fn to_string_lossy(&self) -> std::borrow::Cow<'_, str> { + pub fn to_string_lossy(&self) -> alloc::borrow::Cow<'_, str> { self.path.to_string_lossy() } - pub fn into_cstring(self, vm: &VirtualMachine) -> PyResult { - std::ffi::CString::new(self.into_bytes()).map_err(|err| err.to_pyexception(vm)) + pub fn into_cstring(self, vm: &VirtualMachine) -> PyResult { + alloc::ffi::CString::new(self.into_bytes()).map_err(|err| err.to_pyexception(vm)) } #[cfg(windows)] diff --git a/crates/vm/src/protocol/buffer.rs b/crates/vm/src/protocol/buffer.rs index 88524a9a9ee..0fe4d15458b 100644 --- a/crates/vm/src/protocol/buffer.rs +++ b/crates/vm/src/protocol/buffer.rs @@ -10,8 +10,9 @@ use crate::{ object::PyObjectPayload, sliceable::SequenceIndexOp, }; +use alloc::borrow::Cow; +use core::{fmt::Debug, ops::Range}; use itertools::Itertools; -use std::{borrow::Cow, fmt::Debug, ops::Range}; pub struct BufferMethods { pub obj_bytes: fn(&PyBuffer) -> BorrowedValue<'_, [u8]>, @@ -21,7 +22,7 @@ pub struct BufferMethods { } impl Debug for BufferMethods { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("BufferMethods") .field("obj_bytes", &(self.obj_bytes as usize)) 
.field("obj_bytes_mut", &(self.obj_bytes_mut as usize)) @@ -134,8 +135,8 @@ impl PyBuffer { pub(crate) unsafe fn drop_without_release(&mut self) { // SAFETY: requirements forwarded from caller unsafe { - std::ptr::drop_in_place(&mut self.obj); - std::ptr::drop_in_place(&mut self.desc); + core::ptr::drop_in_place(&mut self.obj); + core::ptr::drop_in_place(&mut self.desc); } } } @@ -414,7 +415,7 @@ pub struct VecBuffer { #[pyclass(flags(BASETYPE, DISALLOW_INSTANTIATION))] impl VecBuffer { pub fn take(&self) -> Vec { - std::mem::take(&mut self.data.lock()) + core::mem::take(&mut self.data.lock()) } } diff --git a/crates/vm/src/protocol/callable.rs b/crates/vm/src/protocol/callable.rs index 5280e04e928..fa5e48d58ba 100644 --- a/crates/vm/src/protocol/callable.rs +++ b/crates/vm/src/protocol/callable.rs @@ -61,8 +61,8 @@ enum TraceEvent { Return, } -impl std::fmt::Display for TraceEvent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for TraceEvent { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { use TraceEvent::*; match self { Call => write!(f, "call"), diff --git a/crates/vm/src/protocol/iter.rs b/crates/vm/src/protocol/iter.rs index f6146543de9..aa6ab6769cd 100644 --- a/crates/vm/src/protocol/iter.rs +++ b/crates/vm/src/protocol/iter.rs @@ -4,8 +4,8 @@ use crate::{ convert::{ToPyObject, ToPyResult}, object::{Traverse, TraverseFn}, }; -use std::borrow::Borrow; -use std::ops::Deref; +use core::borrow::Borrow; +use core::ops::Deref; /// Iterator Protocol // https://docs.python.org/3/c-api/iter.html @@ -223,7 +223,7 @@ where vm: &'a VirtualMachine, obj: O, // creating PyIter is zero-cost length_hint: Option, - _phantom: std::marker::PhantomData, + _phantom: core::marker::PhantomData, } unsafe impl Traverse for PyIterIter<'_, T, O> @@ -244,7 +244,7 @@ where vm, obj, length_hint, - _phantom: std::marker::PhantomData, + _phantom: core::marker::PhantomData, } } } diff --git 
a/crates/vm/src/protocol/mapping.rs b/crates/vm/src/protocol/mapping.rs index 43dafeb9238..6c200043e35 100644 --- a/crates/vm/src/protocol/mapping.rs +++ b/crates/vm/src/protocol/mapping.rs @@ -22,8 +22,8 @@ pub struct PyMappingSlots { >, } -impl std::fmt::Debug for PyMappingSlots { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyMappingSlots { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str("PyMappingSlots") } } @@ -56,8 +56,8 @@ pub struct PyMappingMethods { Option, &PyObject, Option, &VirtualMachine) -> PyResult<()>>, } -impl std::fmt::Debug for PyMappingMethods { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyMappingMethods { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str("PyMappingMethods") } } diff --git a/crates/vm/src/protocol/number.rs b/crates/vm/src/protocol/number.rs index c208bf26de8..58891d1d710 100644 --- a/crates/vm/src/protocol/number.rs +++ b/crates/vm/src/protocol/number.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use core::ops::Deref; use crossbeam_utils::atomic::AtomicCell; diff --git a/crates/vm/src/protocol/sequence.rs b/crates/vm/src/protocol/sequence.rs index 888ef91565f..cee46a29089 100644 --- a/crates/vm/src/protocol/sequence.rs +++ b/crates/vm/src/protocol/sequence.rs @@ -29,8 +29,8 @@ pub struct PySequenceSlots { pub inplace_repeat: AtomicCell, isize, &VirtualMachine) -> PyResult>>, } -impl std::fmt::Debug for PySequenceSlots { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PySequenceSlots { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str("PySequenceSlots") } } @@ -83,8 +83,8 @@ pub struct PySequenceMethods { pub inplace_repeat: Option, isize, &VirtualMachine) -> PyResult>, } -impl std::fmt::Debug for PySequenceMethods { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { +impl core::fmt::Debug for PySequenceMethods { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str("PySequenceMethods") } } diff --git a/crates/vm/src/py_io.rs b/crates/vm/src/py_io.rs index a8063673a70..5649463b30e 100644 --- a/crates/vm/src/py_io.rs +++ b/crates/vm/src/py_io.rs @@ -3,7 +3,9 @@ use crate::{ builtins::{PyBaseExceptionRef, PyBytes, PyStr}, common::ascii, }; -use std::{fmt, io, ops}; +use alloc::fmt; +use core::ops; +use std::io; pub trait Write { type Error; diff --git a/crates/vm/src/py_serde.rs b/crates/vm/src/py_serde.rs index f9a5f4bc060..945068113f1 100644 --- a/crates/vm/src/py_serde.rs +++ b/crates/vm/src/py_serde.rs @@ -130,7 +130,7 @@ impl<'de> DeserializeSeed<'de> for PyObjectDeserializer<'de> { impl<'de> Visitor<'de> for PyObjectDeserializer<'de> { type Value = PyObjectRef; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { formatter.write_str("a type that can deserialize in Python") } diff --git a/crates/vm/src/readline.rs b/crates/vm/src/readline.rs index 77402dc6839..d62d520ecbd 100644 --- a/crates/vm/src/readline.rs +++ b/crates/vm/src/readline.rs @@ -5,7 +5,7 @@ use std::{io, path::Path}; -type OtherError = Box; +type OtherError = Box; type OtherResult = Result; pub enum ReadlineResult { diff --git a/crates/vm/src/scope.rs b/crates/vm/src/scope.rs index 4f80e9999ec..74392dc9b73 100644 --- a/crates/vm/src/scope.rs +++ b/crates/vm/src/scope.rs @@ -1,5 +1,5 @@ use crate::{VirtualMachine, builtins::PyDictRef, function::ArgMapping}; -use std::fmt; +use alloc::fmt; #[derive(Clone)] pub struct Scope { diff --git a/crates/vm/src/sequence.rs b/crates/vm/src/sequence.rs index 6e03ad1697e..0bc12fd2631 100644 --- a/crates/vm/src/sequence.rs +++ b/crates/vm/src/sequence.rs @@ -6,8 +6,8 @@ use crate::{ types::PyComparisonOp, vm::{MAX_MEMORY_SIZE, VirtualMachine}, }; +use 
core::ops::{Deref, Range}; use optional::Optioned; -use std::ops::{Deref, Range}; pub trait MutObjectSequenceOp { type Inner: ?Sized; @@ -100,7 +100,7 @@ where fn mul(&self, vm: &VirtualMachine, n: isize) -> PyResult> { let n = vm.check_repeat_or_overflow_error(self.as_ref().len(), n)?; - if n > 1 && std::mem::size_of_val(self.as_ref()) >= MAX_MEMORY_SIZE / n { + if n > 1 && core::mem::size_of_val(self.as_ref()) >= MAX_MEMORY_SIZE / n { return Err(vm.new_memory_error("")); } diff --git a/crates/vm/src/signal.rs b/crates/vm/src/signal.rs index 1074c8e8f11..4a1b84a1521 100644 --- a/crates/vm/src/signal.rs +++ b/crates/vm/src/signal.rs @@ -1,12 +1,8 @@ #![cfg_attr(target_os = "wasi", allow(dead_code))] use crate::{PyResult, VirtualMachine}; -use std::{ - fmt, - sync::{ - atomic::{AtomicBool, Ordering}, - mpsc, - }, -}; +use alloc::fmt; +use core::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc; pub(crate) const NSIG: usize = 64; static ANY_TRIGGERED: AtomicBool = AtomicBool::new(false); diff --git a/crates/vm/src/sliceable.rs b/crates/vm/src/sliceable.rs index 786b66fb36a..e416f5a1b49 100644 --- a/crates/vm/src/sliceable.rs +++ b/crates/vm/src/sliceable.rs @@ -3,9 +3,9 @@ use crate::{ PyObject, PyResult, VirtualMachine, builtins::{int::PyInt, slice::PySlice}, }; +use core::ops::Range; use malachite_bigint::BigInt; use num_traits::{Signed, ToPrimitive}; -use std::ops::Range; pub trait SliceableSequenceMutOp where diff --git a/crates/vm/src/stdlib/ast/elif_else_clause.rs b/crates/vm/src/stdlib/ast/elif_else_clause.rs index 581fc499b8a..e2a8789dd08 100644 --- a/crates/vm/src/stdlib/ast/elif_else_clause.rs +++ b/crates/vm/src/stdlib/ast/elif_else_clause.rs @@ -3,7 +3,7 @@ use rustpython_compiler_core::SourceFile; pub(super) fn ast_to_object( clause: ruff::ElifElseClause, - mut rest: std::vec::IntoIter, + mut rest: alloc::vec::IntoIter, vm: &VirtualMachine, source_file: &SourceFile, ) -> PyObjectRef { diff --git a/crates/vm/src/stdlib/ast/parameter.rs 
b/crates/vm/src/stdlib/ast/parameter.rs index 87fa736687b..44fcbb2b464 100644 --- a/crates/vm/src/stdlib/ast/parameter.rs +++ b/crates/vm/src/stdlib/ast/parameter.rs @@ -403,7 +403,7 @@ fn merge_keyword_parameter_defaults( kw_only_args: KeywordParameters, defaults: ParameterDefaults, ) -> Vec { - std::iter::zip(kw_only_args.keywords, defaults.defaults) + core::iter::zip(kw_only_args.keywords, defaults.defaults) .map(|(parameter, default)| ruff::ParameterWithDefault { node_index: Default::default(), parameter, diff --git a/crates/vm/src/stdlib/ast/string.rs b/crates/vm/src/stdlib/ast/string.rs index f3df8d99262..ffa5a3a958a 100644 --- a/crates/vm/src/stdlib/ast/string.rs +++ b/crates/vm/src/stdlib/ast/string.rs @@ -12,7 +12,7 @@ fn ruff_fstring_value_into_iter( }); (0..fstring_value.as_slice().len()).map(move |i| { let tmp = fstring_value.iter_mut().nth(i).unwrap(); - std::mem::replace(tmp, default.clone()) + core::mem::replace(tmp, default.clone()) }) } @@ -28,7 +28,7 @@ fn ruff_fstring_element_into_iter( (0..fstring_element.into_iter().len()).map(move |i| { let fstring_element = &mut fstring_element; let tmp = fstring_element.into_iter().nth(i).unwrap(); - std::mem::replace(tmp, default.clone()) + core::mem::replace(tmp, default.clone()) }) } diff --git a/crates/vm/src/stdlib/atexit.rs b/crates/vm/src/stdlib/atexit.rs index b1832b5481d..2286c36f1db 100644 --- a/crates/vm/src/stdlib/atexit.rs +++ b/crates/vm/src/stdlib/atexit.rs @@ -34,7 +34,7 @@ mod atexit { #[pyfunction] pub fn _run_exitfuncs(vm: &VirtualMachine) { - let funcs: Vec<_> = std::mem::take(&mut *vm.state.atexit_funcs.lock()); + let funcs: Vec<_> = core::mem::take(&mut *vm.state.atexit_funcs.lock()); for (func, args) in funcs.into_iter().rev() { if let Err(e) = func.call(args, vm) { let exit = e.fast_isinstance(vm.ctx.exceptions.system_exit); diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index 7cd91f8b4b7..c82161fc553 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ 
b/crates/vm/src/stdlib/builtins.rs @@ -175,7 +175,7 @@ mod builtins { let source = source.borrow_bytes(); // TODO: compiler::compile should probably get bytes - let source = std::str::from_utf8(&source) + let source = core::str::from_utf8(&source) .map_err(|e| vm.new_unicode_decode_error(e.to_string()))?; let flags = args.flags.map_or(Ok(0), |v| v.try_to_primitive(vm))?; @@ -333,7 +333,7 @@ mod builtins { )); } - let source = std::str::from_utf8(source).map_err(|err| { + let source = core::str::from_utf8(source).map_err(|err| { let msg = format!( "(unicode error) 'utf-8' codec can't decode byte 0x{:x?} in position {}: invalid start byte", source[err.valid_up_to()], @@ -605,7 +605,7 @@ mod builtins { } let candidates = match args.args.len().cmp(&1) { - std::cmp::Ordering::Greater => { + core::cmp::Ordering::Greater => { if default.is_some() { return Err(vm.new_type_error(format!( "Cannot specify a default for {func_name}() with multiple positional arguments" @@ -613,8 +613,8 @@ mod builtins { } args.args } - std::cmp::Ordering::Equal => args.args[0].try_to_value(vm)?, - std::cmp::Ordering::Less => { + core::cmp::Ordering::Equal => args.args[0].try_to_value(vm)?, + core::cmp::Ordering::Less => { // zero arguments means type error: return Err( vm.new_type_error(format!("{func_name} expected at least 1 argument, got 0")) diff --git a/crates/vm/src/stdlib/codecs.rs b/crates/vm/src/stdlib/codecs.rs index 821b313090c..1661eef1750 100644 --- a/crates/vm/src/stdlib/codecs.rs +++ b/crates/vm/src/stdlib/codecs.rs @@ -270,7 +270,7 @@ mod _codecs { wide.len() as i32, std::ptr::null_mut(), 0, - std::ptr::null(), + core::ptr::null(), std::ptr::null_mut(), ) }; @@ -291,7 +291,7 @@ mod _codecs { wide.len() as i32, buffer.as_mut_ptr().cast(), size, - std::ptr::null(), + core::ptr::null(), if errors == "strict" { &mut used_default_char } else { @@ -472,7 +472,7 @@ mod _codecs { wide.len() as i32, std::ptr::null_mut(), 0, - std::ptr::null(), + core::ptr::null(), std::ptr::null_mut(), 
) }; @@ -493,7 +493,7 @@ mod _codecs { wide.len() as i32, buffer.as_mut_ptr().cast(), size, - std::ptr::null(), + core::ptr::null(), if errors == "strict" { &mut used_default_char } else { diff --git a/crates/vm/src/stdlib/collections.rs b/crates/vm/src/stdlib/collections.rs index eae56968cba..1249fa9315d 100644 --- a/crates/vm/src/stdlib/collections.rs +++ b/crates/vm/src/stdlib/collections.rs @@ -22,9 +22,9 @@ mod _collections { }, utils::collection_repr, }; + use alloc::collections::VecDeque; + use core::cmp::max; use crossbeam_utils::atomic::AtomicCell; - use std::cmp::max; - use std::collections::VecDeque; #[pyattr] #[pyclass(module = "collections", name = "deque", unhashable = true)] @@ -157,7 +157,7 @@ mod _collections { let mut created = VecDeque::from(elements); let mut borrowed = self.borrow_deque_mut(); created.append(&mut borrowed); - std::mem::swap(&mut created, &mut borrowed); + core::mem::swap(&mut created, &mut borrowed); Ok(()) } @@ -426,7 +426,7 @@ mod _collections { inner.get(index).map(|r| r.as_ref()) } - fn do_lock(&self) -> impl std::ops::Deref { + fn do_lock(&self) -> impl core::ops::Deref { self.borrow_deque() } } @@ -484,7 +484,7 @@ mod _collections { // `maxlen` is better to be defined as UnsafeCell in common practice, // but then more type works without any safety benefits let unsafe_maxlen = - &zelf.maxlen as *const _ as *const std::cell::UnsafeCell>; + &zelf.maxlen as *const _ as *const core::cell::UnsafeCell>; *(*unsafe_maxlen).get() = maxlen; } if let Some(elements) = elements { diff --git a/crates/vm/src/stdlib/ctypes.rs b/crates/vm/src/stdlib/ctypes.rs index a9c0636bd12..f3b6dd25aca 100644 --- a/crates/vm/src/stdlib/ctypes.rs +++ b/crates/vm/src/stdlib/ctypes.rs @@ -15,11 +15,11 @@ use crate::{ class::PyClassImpl, types::TypeDataRef, }; -use std::ffi::{ +use core::ffi::{ c_double, c_float, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, c_ulonglong, c_ushort, }; -use std::mem; +use core::mem; use 
widestring::WideChar; pub use array::PyCArray; @@ -387,7 +387,7 @@ pub(crate) mod _ctypes { const RTLD_GLOBAL: i32 = 0; #[pyattr] - const SIZEOF_TIME_T: usize = std::mem::size_of::(); + const SIZEOF_TIME_T: usize = core::mem::size_of::(); #[pyattr] const CTYPES_MAX_ARGCOUNT: usize = 1024; @@ -535,11 +535,11 @@ pub(crate) mod _ctypes { { return Ok(super::get_size(type_str.as_ref())); } - return Ok(std::mem::size_of::()); + return Ok(core::mem::size_of::()); } // Pointer types if type_obj.fast_issubclass(PyCPointer::static_type()) { - return Ok(std::mem::size_of::()); + return Ok(core::mem::size_of::()); } return Err(vm.new_type_error("this type has no size")); } @@ -550,7 +550,7 @@ pub(crate) mod _ctypes { return Ok(cdata.size()); } if obj.fast_isinstance(PyCPointer::static_type()) { - return Ok(std::mem::size_of::()); + return Ok(core::mem::size_of::()); } Err(vm.new_type_error("this type has no size")) @@ -596,13 +596,17 @@ pub(crate) mod _ctypes { } None => { // dlopen(NULL, mode) to get the current process handle (for pythonapi) - let handle = unsafe { libc::dlopen(std::ptr::null(), mode) }; + let handle = unsafe { libc::dlopen(core::ptr::null(), mode) }; if handle.is_null() { let err = unsafe { libc::dlerror() }; let msg = if err.is_null() { "dlopen() error".to_string() } else { - unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() } + unsafe { + core::ffi::CStr::from_ptr(err) + .to_string_lossy() + .into_owned() + } }; return Err(vm.new_os_error(msg)); } @@ -641,7 +645,7 @@ pub(crate) mod _ctypes { name: crate::builtins::PyStrRef, vm: &VirtualMachine, ) -> PyResult { - let symbol_name = std::ffi::CString::new(name.as_str()) + let symbol_name = alloc::ffi::CString::new(name.as_str()) .map_err(|_| vm.new_value_error("symbol name contains null byte"))?; // Clear previous error @@ -652,7 +656,11 @@ pub(crate) mod _ctypes { // Check for error via dlerror first let err = unsafe { libc::dlerror() }; if !err.is_null() { - let msg = unsafe { 
std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() }; + let msg = unsafe { + core::ffi::CStr::from_ptr(err) + .to_string_lossy() + .into_owned() + }; return Err(vm.new_os_error(msg)); } @@ -851,7 +859,7 @@ pub(crate) mod _ctypes { } if obj.fast_isinstance(PyCPointer::static_type()) { // Pointer alignment is always pointer size - return Ok(std::mem::align_of::()); + return Ok(core::mem::align_of::()); } if obj.fast_isinstance(PyCUnion::static_type()) { // Calculate alignment from _fields_ @@ -914,7 +922,7 @@ pub(crate) mod _ctypes { #[pyfunction] fn resize(obj: PyObjectRef, size: isize, vm: &VirtualMachine) -> PyResult<()> { - use std::borrow::Cow; + use alloc::borrow::Cow; // 1. Get StgInfo from object's class (validates ctypes instance) let stg_info = obj @@ -1148,8 +1156,8 @@ pub(crate) mod _ctypes { } let raw_ptr = ptr as *mut crate::object::PyObject; unsafe { - let obj = crate::PyObjectRef::from_raw(std::ptr::NonNull::new_unchecked(raw_ptr)); - let obj = std::mem::ManuallyDrop::new(obj); + let obj = crate::PyObjectRef::from_raw(core::ptr::NonNull::new_unchecked(raw_ptr)); + let obj = core::mem::ManuallyDrop::new(obj); Ok((*obj).clone()) } } @@ -1208,12 +1216,12 @@ pub(crate) mod _ctypes { FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - std::ptr::null(), + core::ptr::null(), error_code, 0, &mut buffer as *mut *mut u16 as *mut u16, 0, - std::ptr::null(), + core::ptr::null(), ) }; @@ -1280,7 +1288,7 @@ pub(crate) mod _ctypes { let vtable = *iunknown; debug_assert!(!vtable.is_null(), "IUnknown vtable is null"); let addref_fn: extern "system" fn(*mut std::ffi::c_void) -> u32 = - std::mem::transmute(*vtable.add(1)); // AddRef is index 1 + core::mem::transmute(*vtable.add(1)); // AddRef is index 1 addref_fn(src_ptr as *mut std::ffi::c_void); } } diff --git a/crates/vm/src/stdlib/ctypes/array.rs b/crates/vm/src/stdlib/ctypes/array.rs index f31c8284d8b..e3444141b2a 100644 --- a/crates/vm/src/stdlib/ctypes/array.rs 
+++ b/crates/vm/src/stdlib/ctypes/array.rs @@ -618,7 +618,7 @@ impl PyCArray { let ptr_val = usize::from_ne_bytes( ptr_bytes .try_into() - .unwrap_or([0; std::mem::size_of::()]), + .unwrap_or([0; core::mem::size_of::()]), ); if ptr_val == 0 { return Ok(vm.ctx.none()); @@ -630,7 +630,7 @@ impl PyCArray { while *ptr.add(len) != 0 { len += 1; } - let bytes = std::slice::from_raw_parts(ptr, len); + let bytes = core::slice::from_raw_parts(ptr, len); Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) } } @@ -643,7 +643,7 @@ impl PyCArray { let ptr_val = usize::from_ne_bytes( ptr_bytes .try_into() - .unwrap_or([0; std::mem::size_of::()]), + .unwrap_or([0; core::mem::size_of::()]), ); if ptr_val == 0 { return Ok(vm.ctx.none()); @@ -655,10 +655,10 @@ impl PyCArray { let mut pos = 0usize; loop { let code = if WCHAR_SIZE == 2 { - let bytes = std::slice::from_raw_parts(ptr.add(pos), 2); + let bytes = core::slice::from_raw_parts(ptr.add(pos), 2); u16::from_ne_bytes([bytes[0], bytes[1]]) as u32 } else { - let bytes = std::slice::from_raw_parts(ptr.add(pos), 4); + let bytes = core::slice::from_raw_parts(ptr.add(pos), 4); u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) }; if code == 0 { @@ -1101,7 +1101,7 @@ impl AsBuffer for PyCArray { len: buffer_len, readonly: false, itemsize, - format: std::borrow::Cow::Owned(fmt), + format: alloc::borrow::Cow::Owned(fmt), dim_desc, } } else { @@ -1256,7 +1256,7 @@ fn add_wchar_array_getsets(array_type: &Py, vm: &VirtualMachine) { // Linux/macOS: sizeof(wchar_t) == 4 (UTF-32) /// Size of wchar_t on this platform -pub(super) const WCHAR_SIZE: usize = std::mem::size_of::(); +pub(super) const WCHAR_SIZE: usize = core::mem::size_of::(); /// Read a single wchar_t from bytes (platform-endian) #[inline] diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/ctypes/base.rs index 0f859b3d10b..1f9eaeef56a 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ b/crates/vm/src/stdlib/ctypes/base.rs @@ -7,15 +7,15 @@ use 
crate::types::{GetDescriptor, Representable}; use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, }; +use alloc::borrow::Cow; +use core::ffi::{ + c_double, c_float, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort, +}; +use core::fmt::Debug; +use core::mem; use crossbeam_utils::atomic::AtomicCell; use num_traits::{Signed, ToPrimitive}; use rustpython_common::lock::PyRwLock; -use std::borrow::Cow; -use std::ffi::{ - c_double, c_float, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort, -}; -use std::fmt::Debug; -use std::mem; use widestring::WideChar; // StgInfo - Storage information for ctypes types @@ -105,8 +105,8 @@ pub struct StgInfo { unsafe impl Send for StgInfo {} unsafe impl Sync for StgInfo {} -impl std::fmt::Debug for StgInfo { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for StgInfo { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("StgInfo") .field("initialized", &self.initialized) .field("size", &self.size) @@ -227,7 +227,7 @@ impl StgInfo { libffi::middle::Type::structure(self.ffi_field_types.iter().cloned()) } else if self.size <= MAX_FFI_STRUCT_SIZE { // Small struct without field types: use bytes array - libffi::middle::Type::structure(std::iter::repeat_n( + libffi::middle::Type::structure(core::iter::repeat_n( libffi::middle::Type::u8(), self.size, )) @@ -242,9 +242,9 @@ impl StgInfo { libffi::middle::Type::pointer() } else if let Some(ref fmt) = self.format { let elem_type = Self::format_to_ffi_type(fmt); - libffi::middle::Type::structure(std::iter::repeat_n(elem_type, self.length)) + libffi::middle::Type::structure(core::iter::repeat_n(elem_type, self.length)) } else { - libffi::middle::Type::structure(std::iter::repeat_n( + libffi::middle::Type::structure(core::iter::repeat_n( libffi::middle::Type::u8(), self.size, )) @@ -373,10 +373,10 @@ pub(super) static 
CDATA_BUFFER_METHODS: BufferMethods = BufferMethods { /// Convert Vec to Vec by reinterpreting the memory (same allocation). fn vec_to_bytes(vec: Vec) -> Vec { - let len = vec.len() * std::mem::size_of::(); - let cap = vec.capacity() * std::mem::size_of::(); + let len = vec.len() * core::mem::size_of::(); + let cap = vec.capacity() * core::mem::size_of::(); let ptr = vec.as_ptr() as *mut u8; - std::mem::forget(vec); + core::mem::forget(vec); unsafe { Vec::from_raw_parts(ptr, len, cap) } } @@ -406,7 +406,7 @@ pub(super) fn str_to_wchar_bytes(s: &str, vm: &VirtualMachine) -> (PyObjectRef, let wchars: Vec = s .chars() .map(|c| c as libc::wchar_t) - .chain(std::iter::once(0)) + .chain(core::iter::once(0)) .collect(); let ptr = wchars.as_ptr() as usize; let bytes = vec_to_bytes(wchars); @@ -486,7 +486,7 @@ impl PyCData { pub unsafe fn at_address(ptr: *const u8, size: usize) -> Self { // = PyCData_AtAddress // SAFETY: Caller must ensure ptr is valid for the lifetime of returned PyCData - let slice: &'static [u8] = unsafe { std::slice::from_raw_parts(ptr, size) }; + let slice: &'static [u8] = unsafe { core::slice::from_raw_parts(ptr, size) }; PyCData { buffer: PyRwLock::new(Cow::Borrowed(slice)), base: PyRwLock::new(None), @@ -534,7 +534,7 @@ impl PyCData { ) -> Self { // = PyCData_FromBaseObj // SAFETY: ptr points into base_obj's buffer, kept alive via base reference - let slice: &'static [u8] = unsafe { std::slice::from_raw_parts(ptr, size) }; + let slice: &'static [u8] = unsafe { core::slice::from_raw_parts(ptr, size) }; PyCData { buffer: PyRwLock::new(Cow::Borrowed(slice)), base: PyRwLock::new(Some(base_obj)), @@ -561,7 +561,7 @@ impl PyCData { vm: &VirtualMachine, ) -> Self { // SAFETY: Caller must ensure ptr is valid for the lifetime of source - let slice: &'static [u8] = unsafe { std::slice::from_raw_parts(ptr, size) }; + let slice: &'static [u8] = unsafe { core::slice::from_raw_parts(ptr, size) }; // Python stores the reference in a dict with key "-1" (unique_key 
pattern) let objects_dict = vm.ctx.new_dict(); @@ -707,7 +707,7 @@ impl PyCData { // (e.g., from from_address pointing to a ctypes buffer) unsafe { let ptr = slice.as_ptr() as *mut u8; - std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr.add(offset), bytes.len()); + core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr.add(offset), bytes.len()); } } Cow::Owned(_) => { @@ -893,7 +893,7 @@ impl PyCData { if let Some(bytes_val) = value.downcast_ref::() { let src = bytes_val.as_bytes(); let to_copy = PyCField::bytes_for_char_array(src); - let copy_len = std::cmp::min(to_copy.len(), size); + let copy_len = core::cmp::min(to_copy.len(), size); self.write_bytes_at_offset(offset, &to_copy[..copy_len]); self.keep_ref(index, value, vm)?; return Ok(()); @@ -936,7 +936,7 @@ impl PyCData { array_buffer.as_ptr() as usize }; let addr_bytes = buffer_addr.to_ne_bytes(); - let len = std::cmp::min(addr_bytes.len(), size); + let len = core::cmp::min(addr_bytes.len(), size); self.write_bytes_at_offset(offset, &addr_bytes[..len]); self.keep_ref(index, value, vm)?; return Ok(()); @@ -1364,7 +1364,7 @@ impl PyCField { if let Some(bytes) = value.downcast_ref::() { let src = bytes.as_bytes(); let mut result = vec![0u8; size]; - let len = std::cmp::min(src.len(), size); + let len = core::cmp::min(src.len(), size); result[..len].copy_from_slice(&src[..len]); Ok(result) } @@ -1372,7 +1372,7 @@ impl PyCField { else if let Some(cdata) = value.downcast_ref::() { let buffer = cdata.buffer.read(); let mut result = vec![0u8; size]; - let len = std::cmp::min(buffer.len(), size); + let len = core::cmp::min(buffer.len(), size); result[..len].copy_from_slice(&buffer[..len]); Ok(result) } @@ -1473,7 +1473,7 @@ impl PyCField { let (converted, ptr) = ensure_z_null_terminated(bytes, vm); let mut result = vec![0u8; size]; let addr_bytes = ptr.to_ne_bytes(); - let len = std::cmp::min(addr_bytes.len(), size); + let len = core::cmp::min(addr_bytes.len(), size); result[..len].copy_from_slice(&addr_bytes[..len]); 
return Ok((result, Some(converted))); } @@ -1482,7 +1482,7 @@ impl PyCField { let v = int_val.as_bigint().to_usize().unwrap_or(0); let mut result = vec![0u8; size]; let bytes = v.to_ne_bytes(); - let len = std::cmp::min(bytes.len(), size); + let len = core::cmp::min(bytes.len(), size); result[..len].copy_from_slice(&bytes[..len]); return Ok((result, None)); } @@ -1498,7 +1498,7 @@ impl PyCField { let (holder, ptr) = str_to_wchar_bytes(s.as_str(), vm); let mut result = vec![0u8; size]; let addr_bytes = ptr.to_ne_bytes(); - let len = std::cmp::min(addr_bytes.len(), size); + let len = core::cmp::min(addr_bytes.len(), size); result[..len].copy_from_slice(&addr_bytes[..len]); return Ok((result, Some(holder))); } @@ -1507,7 +1507,7 @@ impl PyCField { let v = int_val.as_bigint().to_usize().unwrap_or(0); let mut result = vec![0u8; size]; let bytes = v.to_ne_bytes(); - let len = std::cmp::min(bytes.len(), size); + let len = core::cmp::min(bytes.len(), size); result[..len].copy_from_slice(&bytes[..len]); return Ok((result, None)); } @@ -1523,7 +1523,7 @@ impl PyCField { let v = int_val.as_bigint().to_usize().unwrap_or(0); let mut result = vec![0u8; size]; let bytes = v.to_ne_bytes(); - let len = std::cmp::min(bytes.len(), size); + let len = core::cmp::min(bytes.len(), size); result[..len].copy_from_slice(&bytes[..len]); return Ok((result, None)); } @@ -2078,7 +2078,7 @@ pub(super) fn bytes_to_pyobject( if ptr == 0 { return Ok(vm.ctx.none()); } - let c_str = unsafe { std::ffi::CStr::from_ptr(ptr as _) }; + let c_str = unsafe { core::ffi::CStr::from_ptr(ptr as _) }; Ok(vm.ctx.new_bytes(c_str.to_bytes().to_vec()).into()) } "Z" => { @@ -2089,7 +2089,7 @@ pub(super) fn bytes_to_pyobject( } let len = unsafe { libc::wcslen(ptr as *const libc::wchar_t) }; let wchars = - unsafe { std::slice::from_raw_parts(ptr as *const libc::wchar_t, len) }; + unsafe { core::slice::from_raw_parts(ptr as *const libc::wchar_t, len) }; let s: String = wchars .iter() .filter_map(|&c| char::from_u32(c as 
u32)) @@ -2149,7 +2149,7 @@ pub(super) fn get_usize_attr( /// Read a pointer value from buffer #[inline] pub(super) fn read_ptr_from_buffer(buffer: &[u8]) -> usize { - const PTR_SIZE: usize = std::mem::size_of::(); + const PTR_SIZE: usize = core::mem::size_of::(); if buffer.len() >= PTR_SIZE { usize::from_ne_bytes(buffer[..PTR_SIZE].try_into().unwrap()) } else { @@ -2242,7 +2242,7 @@ pub(super) fn get_field_size(field_type: &PyObject, vm: &VirtualMachine) -> PyRe return Ok(s); } - Ok(std::mem::size_of::()) + Ok(core::mem::size_of::()) } /// Get the alignment of a ctypes field type diff --git a/crates/vm/src/stdlib/ctypes/function.rs b/crates/vm/src/stdlib/ctypes/function.rs index 55a42f0ba15..5906bc91cd4 100644 --- a/crates/vm/src/stdlib/ctypes/function.rs +++ b/crates/vm/src/stdlib/ctypes/function.rs @@ -16,6 +16,9 @@ use crate::{ types::{AsBuffer, Callable, Constructor, Initializer, Representable}, vm::thread::with_current_vm, }; +use alloc::borrow::Cow; +use core::ffi::c_void; +use core::fmt::Debug; use libffi::{ low, middle::{Arg, Cif, Closure, CodePtr, Type}, @@ -23,9 +26,6 @@ use libffi::{ use libloading::Symbol; use num_traits::{Signed, ToPrimitive}; use rustpython_common::lock::PyRwLock; -use std::borrow::Cow; -use std::ffi::c_void; -use std::fmt::Debug; // Internal function addresses for special ctypes functions pub(super) const INTERNAL_CAST_ADDR: usize = 1; @@ -37,7 +37,7 @@ std::thread_local! { /// Thread-local storage for ctypes errno /// This is separate from the system errno - ctypes swaps them during FFI calls /// when use_errno=True is specified. - static CTYPES_LOCAL_ERRNO: std::cell::Cell = const { std::cell::Cell::new(0) }; + static CTYPES_LOCAL_ERRNO: core::cell::Cell = const { core::cell::Cell::new(0) }; } /// Get ctypes thread-local errno value @@ -79,7 +79,7 @@ where #[cfg(windows)] std::thread_local! 
{ /// Thread-local storage for ctypes last_error (Windows only) - static CTYPES_LOCAL_LAST_ERROR: std::cell::Cell = const { std::cell::Cell::new(0) }; + static CTYPES_LOCAL_LAST_ERROR: core::cell::Cell = const { core::cell::Cell::new(0) }; } #[cfg(windows)] @@ -135,14 +135,14 @@ fn ffi_type_from_tag(tag: u8) -> Type { b'i' => Type::i32(), b'I' => Type::u32(), b'l' => { - if std::mem::size_of::() == 8 { + if core::mem::size_of::() == 8 { Type::i64() } else { Type::i32() } } b'L' => { - if std::mem::size_of::() == 8 { + if core::mem::size_of::() == 8 { Type::u64() } else { Type::u32() @@ -154,7 +154,7 @@ fn ffi_type_from_tag(tag: u8) -> Type { b'd' | b'g' => Type::f64(), b'?' => Type::u8(), b'u' => { - if std::mem::size_of::() == 2 { + if core::mem::size_of::() == 2 { Type::u16() } else { Type::u32() @@ -207,7 +207,7 @@ fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult value from buffer if let Some(simple) = value.downcast_ref::() { let buffer = simple.0.buffer.read(); - if buffer.len() >= std::mem::size_of::() { + if buffer.len() >= core::mem::size_of::() { let addr = super::base::read_ptr_from_buffer(&buffer); return Ok(FfiArgValue::Pointer(addr)); } @@ -283,7 +283,7 @@ fn conv_param(value: &PyObject, vm: &VirtualMachine) -> PyResult { let wide: Vec = s .as_str() .encode_utf16() - .chain(std::iter::once(0)) + .chain(core::iter::once(0)) .collect(); let wide_bytes: Vec = wide.iter().flat_map(|&x| x.to_ne_bytes()).collect(); let keep = vm.ctx.new_bytes(wide_bytes); @@ -499,7 +499,7 @@ impl Initializer for PyCFuncPtrType { new_type.check_not_initialized(vm)?; - let ptr_size = std::mem::size_of::(); + let ptr_size = core::mem::size_of::(); let mut stg_info = StgInfo::new(ptr_size, ptr_size); stg_info.format = Some("X{}".to_string()); stg_info.length = 1; @@ -552,7 +552,7 @@ pub(super) struct PyCFuncPtr { } impl Debug for PyCFuncPtr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut 
core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyCFuncPtr") .field("func_ptr", &self.get_func_ptr()) .finish() @@ -567,9 +567,9 @@ fn extract_ptr_from_arg(arg: &PyObject, vm: &VirtualMachine) -> PyResult } if let Some(simple) = arg.downcast_ref::() { let buffer = simple.0.buffer.read(); - if buffer.len() >= std::mem::size_of::() { + if buffer.len() >= core::mem::size_of::() { return Ok(usize::from_ne_bytes( - buffer[..std::mem::size_of::()].try_into().unwrap(), + buffer[..core::mem::size_of::()].try_into().unwrap(), )); } } @@ -612,7 +612,7 @@ fn string_at_impl(ptr: usize, size: isize, vm: &VirtualMachine) -> PyResult { } size_usize }; - let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }; + let bytes = unsafe { core::slice::from_raw_parts(ptr, len) }; Ok(vm.ctx.new_bytes(bytes.to_vec()).into()) } @@ -627,12 +627,12 @@ fn wstring_at_impl(ptr: usize, size: isize, vm: &VirtualMachine) -> PyResult { } else { // Overflow check for huge size values let size_usize = size as usize; - if size_usize > isize::MAX as usize / std::mem::size_of::() { + if size_usize > isize::MAX as usize / core::mem::size_of::() { return Err(vm.new_overflow_error("string too long")); } size_usize }; - let wchars = unsafe { std::slice::from_raw_parts(w_ptr, len) }; + let wchars = unsafe { core::slice::from_raw_parts(w_ptr, len) }; // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide // macOS/Linux: wchar_t = i32 (UTF-32) -> convert via char::from_u32 @@ -815,7 +815,7 @@ impl Constructor for PyCFuncPtr { // 3. Tuple argument: (name, dll) form // 4. 
Callable: callback creation - let ptr_size = std::mem::size_of::(); + let ptr_size = core::mem::size_of::(); if args.args.is_empty() { return PyCFuncPtr { @@ -1513,11 +1513,11 @@ fn convert_raw_result( RawResult::Void => return None, RawResult::Pointer(ptr) => { let bytes = ptr.to_ne_bytes(); - (bytes.to_vec(), std::mem::size_of::()) + (bytes.to_vec(), core::mem::size_of::()) } RawResult::Value(val) => { let bytes = val.to_ne_bytes(); - (bytes.to_vec(), std::mem::size_of::()) + (bytes.to_vec(), core::mem::size_of::()) } }; @@ -1702,7 +1702,7 @@ impl Callable for PyCFuncPtr { None => { debug_assert!(false, "NULL function pointer"); // In release mode, this will crash - CodePtr(std::ptr::null_mut()) + CodePtr(core::ptr::null_mut()) } }; @@ -1758,7 +1758,7 @@ impl AsBuffer for PyCFuncPtr { stg_info.size, ) } else { - (Cow::Borrowed("X{}"), std::mem::size_of::()) + (Cow::Borrowed("X{}"), core::mem::size_of::()) }; let desc = BufferDescriptor { len: itemsize, @@ -1902,7 +1902,7 @@ fn ffi_to_python(ty: &Py, ptr: *const c_void, vm: &VirtualMachine) -> Py if cstr_ptr.is_null() { vm.ctx.none() } else { - let cstr = std::ffi::CStr::from_ptr(cstr_ptr); + let cstr = core::ffi::CStr::from_ptr(cstr_ptr); vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into() } } @@ -1916,7 +1916,7 @@ fn ffi_to_python(ty: &Py, ptr: *const c_void, vm: &VirtualMachine) -> Py while *wstr_ptr.add(len) != 0 { len += 1; } - let slice = std::slice::from_raw_parts(wstr_ptr, len); + let slice = core::slice::from_raw_parts(wstr_ptr, len); // Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide // Unix: wchar_t = i32 (UTF-32) -> convert via char::from_u32 #[cfg(windows)] @@ -2113,7 +2113,7 @@ pub(super) struct PyCThunk { } impl Debug for PyCThunk { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyCThunk") .field("callable", &self.callable) .finish() diff --git 
a/crates/vm/src/stdlib/ctypes/library.rs b/crates/vm/src/stdlib/ctypes/library.rs index 7512ce29d8a..35ccb433845 100644 --- a/crates/vm/src/stdlib/ctypes/library.rs +++ b/crates/vm/src/stdlib/ctypes/library.rs @@ -1,9 +1,9 @@ use crate::VirtualMachine; +use alloc::fmt; use libloading::Library; use rustpython_common::lock::{PyMutex, PyRwLock}; use std::collections::HashMap; use std::ffi::OsStr; -use std::fmt; #[cfg(unix)] use libloading::os::unix::Library as UnixLibrary; @@ -54,7 +54,7 @@ impl SharedLibrary { // On Windows: HMODULE (*mut c_void) // On Unix: *mut c_void from dlopen // We use transmute_copy to read the handle without consuming the Library - unsafe { std::mem::transmute_copy::(l) } + unsafe { core::mem::transmute_copy::(l) } } else { 0 } diff --git a/crates/vm/src/stdlib/ctypes/pointer.rs b/crates/vm/src/stdlib/ctypes/pointer.rs index ae97b741b3c..f564fd1965c 100644 --- a/crates/vm/src/stdlib/ctypes/pointer.rs +++ b/crates/vm/src/stdlib/ctypes/pointer.rs @@ -8,8 +8,8 @@ use crate::{ class::StaticType, function::{FuncArgs, OptionalArg}, }; +use alloc::borrow::Cow; use num_traits::ToPrimitive; -use std::borrow::Cow; #[pyclass(name = "PyCPointerType", base = PyType, module = "_ctypes")] #[derive(Debug)] @@ -37,7 +37,7 @@ impl Initializer for PyCPointerType { .and_then(|obj| obj.downcast::().ok()); // Initialize StgInfo for pointer type - let pointer_size = std::mem::size_of::(); + let pointer_size = core::mem::size_of::(); let mut stg_info = StgInfo::new(pointer_size, pointer_size); stg_info.proto = proto; stg_info.paramfunc = super::base::ParamFunc::Pointer; @@ -232,7 +232,7 @@ impl Constructor for PyCPointer { // Create a new PyCPointer instance with NULL pointer (all zeros) // Initial contents is set via __init__ if provided - let cdata = PyCData::from_bytes(vec![0u8; std::mem::size_of::()], None); + let cdata = PyCData::from_bytes(vec![0u8; core::mem::size_of::()], None); // pointer instance has b_length set to 2 (for index 0 and 1) 
cdata.length.store(2); PyCPointer(cdata) @@ -299,7 +299,7 @@ impl PyCPointer { let proto_type = stg_info.proto(); let element_size = proto_type .stg_info_opt() - .map_or(std::mem::size_of::(), |info| info.size); + .map_or(core::mem::size_of::(), |info| info.size); // Create instance that references the memory directly // PyCData.into_ref_with_type works for all ctypes (simple, structure, union, array, pointer) @@ -383,7 +383,7 @@ impl PyCPointer { let proto_type = stg_info.proto(); let element_size = proto_type .stg_info_opt() - .map_or(std::mem::size_of::(), |info| info.size); + .map_or(core::mem::size_of::(), |info| info.size); // offset = index * iteminfo->size let offset = index * element_size as isize; @@ -468,7 +468,7 @@ impl PyCPointer { let element_size = if let Some(ref proto_type) = stg_info.proto { proto_type.stg_info_opt().expect("proto has StgInfo").size } else { - std::mem::size_of::() + core::mem::size_of::() }; let type_code = stg_info .proto @@ -489,7 +489,7 @@ impl PyCPointer { // Optimized contiguous copy let start_addr = (ptr_value as isize + start * element_size as isize) as *const u8; unsafe { - result.extend_from_slice(std::slice::from_raw_parts(start_addr, len)); + result.extend_from_slice(core::slice::from_raw_parts(start_addr, len)); } } else { let mut cur = start; @@ -510,7 +510,7 @@ impl PyCPointer { return Ok(vm.ctx.new_str("").into()); } let mut result = String::with_capacity(len); - let wchar_size = std::mem::size_of::(); + let wchar_size = core::mem::size_of::(); let mut cur = start; for _ in 0..len { let addr = (ptr_value as isize + cur * wchar_size as isize) as *const libc::wchar_t; @@ -578,7 +578,7 @@ impl PyCPointer { let element_size = proto_type .stg_info_opt() - .map_or(std::mem::size_of::(), |info| info.size); + .map_or(core::mem::size_of::(), |info| info.size); // Calculate address let offset = index * element_size as isize; @@ -595,7 +595,7 @@ impl PyCPointer { let copy_len = src_buffer.len().min(element_size); unsafe { let 
dest_ptr = addr as *mut u8; - std::ptr::copy_nonoverlapping(src_buffer.as_ptr(), dest_ptr, copy_len); + core::ptr::copy_nonoverlapping(src_buffer.as_ptr(), dest_ptr, copy_len); } } else { // Handle z/Z specially to store converted value @@ -641,43 +641,43 @@ impl PyCPointer { // Multi-byte types need read_unaligned for safety on strict-alignment architectures Some("h") => Ok(vm .ctx - .new_int(std::ptr::read_unaligned(ptr as *const i16) as i32) + .new_int(core::ptr::read_unaligned(ptr as *const i16) as i32) .into()), Some("H") => Ok(vm .ctx - .new_int(std::ptr::read_unaligned(ptr as *const u16) as i32) + .new_int(core::ptr::read_unaligned(ptr as *const u16) as i32) .into()), Some("i") | Some("l") => Ok(vm .ctx - .new_int(std::ptr::read_unaligned(ptr as *const i32)) + .new_int(core::ptr::read_unaligned(ptr as *const i32)) .into()), Some("I") | Some("L") => Ok(vm .ctx - .new_int(std::ptr::read_unaligned(ptr as *const u32)) + .new_int(core::ptr::read_unaligned(ptr as *const u32)) .into()), Some("q") => Ok(vm .ctx - .new_int(std::ptr::read_unaligned(ptr as *const i64)) + .new_int(core::ptr::read_unaligned(ptr as *const i64)) .into()), Some("Q") => Ok(vm .ctx - .new_int(std::ptr::read_unaligned(ptr as *const u64)) + .new_int(core::ptr::read_unaligned(ptr as *const u64)) .into()), Some("f") => Ok(vm .ctx - .new_float(std::ptr::read_unaligned(ptr as *const f32) as f64) + .new_float(core::ptr::read_unaligned(ptr as *const f32) as f64) .into()), Some("d") | Some("g") => Ok(vm .ctx - .new_float(std::ptr::read_unaligned(ptr as *const f64)) + .new_float(core::ptr::read_unaligned(ptr as *const f64)) .into()), Some("P") | Some("z") | Some("Z") => Ok(vm .ctx - .new_int(std::ptr::read_unaligned(ptr as *const usize)) + .new_int(core::ptr::read_unaligned(ptr as *const usize)) .into()), _ => { // Default: read as bytes - let bytes = std::slice::from_raw_parts(ptr, size).to_vec(); + let bytes = core::slice::from_raw_parts(ptr, size).to_vec(); Ok(vm.ctx.new_bytes(bytes).into()) } } @@ 
-708,7 +708,7 @@ impl PyCPointer { "bytes/string or integer address expected".to_owned(), )); }; - std::ptr::write_unaligned(ptr as *mut usize, ptr_val); + core::ptr::write_unaligned(ptr as *mut usize, ptr_val); return Ok(()); } _ => {} @@ -723,19 +723,19 @@ impl PyCPointer { *ptr = i.to_u8().expect("int too large"); } 2 => { - std::ptr::write_unaligned( + core::ptr::write_unaligned( ptr as *mut i16, i.to_i16().expect("int too large"), ); } 4 => { - std::ptr::write_unaligned( + core::ptr::write_unaligned( ptr as *mut i32, i.to_i32().expect("int too large"), ); } 8 => { - std::ptr::write_unaligned( + core::ptr::write_unaligned( ptr as *mut i64, i.to_i64().expect("int too large"), ); @@ -743,7 +743,7 @@ impl PyCPointer { _ => { let bytes = i.to_signed_bytes_le(); let copy_len = bytes.len().min(size); - std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); + core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); } } return Ok(()); @@ -754,10 +754,10 @@ impl PyCPointer { let f = float_val.to_f64(); match size { 4 => { - std::ptr::write_unaligned(ptr as *mut f32, f as f32); + core::ptr::write_unaligned(ptr as *mut f32, f as f32); } 8 => { - std::ptr::write_unaligned(ptr as *mut f64, f); + core::ptr::write_unaligned(ptr as *mut f64, f); } _ => {} } @@ -767,7 +767,7 @@ impl PyCPointer { // Try bytes if let Ok(bytes) = value.try_bytes_like(vm, |b| b.to_vec()) { let copy_len = bytes.len().min(size); - std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); + core::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, copy_len); return Ok(()); } diff --git a/crates/vm/src/stdlib/ctypes/simple.rs b/crates/vm/src/stdlib/ctypes/simple.rs index 803b38d6e05..26d24e2d3e4 100644 --- a/crates/vm/src/stdlib/ctypes/simple.rs +++ b/crates/vm/src/stdlib/ctypes/simple.rs @@ -13,9 +13,9 @@ use crate::function::{Either, FuncArgs, OptionalArg}; use crate::protocol::{BufferDescriptor, PyBuffer, PyNumberMethods}; use crate::types::{AsBuffer, AsNumber, Constructor, Initializer, 
Representable}; use crate::{AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; +use alloc::borrow::Cow; +use core::fmt::Debug; use num_traits::ToPrimitive; -use std::borrow::Cow; -use std::fmt::Debug; /// Valid type codes for ctypes simple types // spell-checker: disable-next-line @@ -27,22 +27,22 @@ pub(super) const SIMPLE_TYPE_CHARS: &str = "cbBhHiIlLdfuzZqQPXOv?g"; fn ctypes_code_to_pep3118(code: char) -> char { match code { // c_int: map based on sizeof(int) - 'i' if std::mem::size_of::() == 2 => 'h', - 'i' if std::mem::size_of::() == 4 => 'i', - 'i' if std::mem::size_of::() == 8 => 'q', - 'I' if std::mem::size_of::() == 2 => 'H', - 'I' if std::mem::size_of::() == 4 => 'I', - 'I' if std::mem::size_of::() == 8 => 'Q', + 'i' if core::mem::size_of::() == 2 => 'h', + 'i' if core::mem::size_of::() == 4 => 'i', + 'i' if core::mem::size_of::() == 8 => 'q', + 'I' if core::mem::size_of::() == 2 => 'H', + 'I' if core::mem::size_of::() == 4 => 'I', + 'I' if core::mem::size_of::() == 8 => 'Q', // c_long: map based on sizeof(long) - 'l' if std::mem::size_of::() == 4 => 'l', - 'l' if std::mem::size_of::() == 8 => 'q', - 'L' if std::mem::size_of::() == 4 => 'L', - 'L' if std::mem::size_of::() == 8 => 'Q', + 'l' if core::mem::size_of::() == 4 => 'l', + 'l' if core::mem::size_of::() == 8 => 'q', + 'L' if core::mem::size_of::() == 4 => 'L', + 'L' if core::mem::size_of::() == 8 => 'Q', // c_bool: map based on sizeof(bool) - typically 1 byte on all platforms - '?' if std::mem::size_of::() == 1 => '?', - '?' if std::mem::size_of::() == 2 => 'H', - '?' if std::mem::size_of::() == 4 => 'L', - '?' if std::mem::size_of::() == 8 => 'Q', + '?' if core::mem::size_of::() == 1 => '?', + '?' if core::mem::size_of::() == 2 => 'H', + '?' if core::mem::size_of::() == 4 => 'L', + '?' 
if core::mem::size_of::() == 8 => 'Q', // Default: use the same code _ => code, } @@ -268,7 +268,7 @@ impl PyCSimpleType { let create_simple_with_value = |type_str: &str, val: &PyObject| -> PyResult { let simple = new_simple_type(Either::B(&cls), vm)?; let buffer_bytes = value_to_bytes_endian(type_str, val, false, vm); - *simple.0.buffer.write() = std::borrow::Cow::Owned(buffer_bytes.clone()); + *simple.0.buffer.write() = alloc::borrow::Cow::Owned(buffer_bytes.clone()); let simple_obj: PyObjectRef = simple.into_ref_with_type(vm, cls.clone())?.into(); // from_param returns CArgObject, not the simple type itself let tag = type_str.as_bytes().first().copied().unwrap_or(b'?'); @@ -418,9 +418,9 @@ impl PyCSimpleType { if let Some(funcptr) = value.downcast_ref::() { let ptr_val = { let buffer = funcptr._base.buffer.read(); - if buffer.len() >= std::mem::size_of::() { + if buffer.len() >= core::mem::size_of::() { usize::from_ne_bytes( - buffer[..std::mem::size_of::()].try_into().unwrap(), + buffer[..core::mem::size_of::()].try_into().unwrap(), ) } else { 0 @@ -441,9 +441,9 @@ impl PyCSimpleType { if matches!(value_type_code.as_deref(), Some("z") | Some("Z")) { let ptr_val = { let buffer = simple.0.buffer.read(); - if buffer.len() >= std::mem::size_of::() { + if buffer.len() >= core::mem::size_of::() { usize::from_ne_bytes( - buffer[..std::mem::size_of::()].try_into().unwrap(), + buffer[..core::mem::size_of::()].try_into().unwrap(), ) } else { 0 @@ -712,7 +712,7 @@ fn create_swapped_types( pub struct PyCSimple(pub PyCData); impl Debug for PyCSimple { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyCSimple") .field("size", &self.0.buffer.read().len()) .finish() @@ -833,7 +833,7 @@ fn value_to_bytes_endian( let v = int_val.as_bigint().to_i128().expect("int too large") as libc::c_long; return to_bytes!(v); } - const SIZE: usize = std::mem::size_of::(); + const 
SIZE: usize = core::mem::size_of::(); vec![0; SIZE] } "L" => { @@ -842,7 +842,7 @@ fn value_to_bytes_endian( let v = int_val.as_bigint().to_i128().expect("int too large") as libc::c_ulong; return to_bytes!(v); } - const SIZE: usize = std::mem::size_of::(); + const SIZE: usize = core::mem::size_of::(); vec![0; SIZE] } "q" => { @@ -938,7 +938,7 @@ fn value_to_bytes_endian( .expect("int too large for pointer"); return to_bytes!(v); } - vec![0; std::mem::size_of::()] + vec![0; core::mem::size_of::()] } "z" => { // c_char_p - pointer to char (stores pointer value from int) @@ -950,7 +950,7 @@ fn value_to_bytes_endian( .expect("int too large for pointer"); return to_bytes!(v); } - vec![0; std::mem::size_of::()] + vec![0; core::mem::size_of::()] } "Z" => { // c_wchar_p - pointer to wchar_t (stores pointer value from int) @@ -962,7 +962,7 @@ fn value_to_bytes_endian( .expect("int too large for pointer"); return to_bytes!(v); } - vec![0; std::mem::size_of::()] + vec![0; core::mem::size_of::()] } "O" => { // py_object - store object id as non-zero marker @@ -1151,7 +1151,7 @@ impl PyCSimple { } // Read null-terminated string at the address unsafe { - let cstr = std::ffi::CStr::from_ptr(ptr as _); + let cstr = core::ffi::CStr::from_ptr(ptr as _); return Ok(vm.ctx.new_bytes(cstr.to_bytes().to_vec()).into()); } } @@ -1168,7 +1168,7 @@ impl PyCSimple { unsafe { let w_ptr = ptr as *const libc::wchar_t; let len = libc::wcslen(w_ptr); - let wchars = std::slice::from_raw_parts(w_ptr, len); + let wchars = core::slice::from_raw_parts(w_ptr, len); #[cfg(windows)] { use rustpython_common::wtf8::Wtf8Buf; @@ -1206,13 +1206,13 @@ impl PyCSimple { // Read value from buffer, swap bytes if needed let buffer = zelf.0.buffer.read(); - let buffer_data: std::borrow::Cow<'_, [u8]> = if swapped { + let buffer_data: alloc::borrow::Cow<'_, [u8]> = if swapped { // Reverse bytes for swapped endian types let mut swapped_bytes = buffer.to_vec(); swapped_bytes.reverse(); - 
std::borrow::Cow::Owned(swapped_bytes) + alloc::borrow::Cow::Owned(swapped_bytes) } else { - std::borrow::Cow::Borrowed(&*buffer) + alloc::borrow::Cow::Borrowed(&*buffer) }; let cls_ref = cls.to_owned(); @@ -1254,7 +1254,7 @@ impl PyCSimple { if type_code == "z" { if let Some(bytes) = value.downcast_ref::() { let (converted, ptr) = super::base::ensure_z_null_terminated(bytes, vm); - *zelf.0.buffer.write() = std::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); + *zelf.0.buffer.write() = alloc::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); *zelf.0.objects.write() = Some(converted); return Ok(()); } @@ -1262,7 +1262,7 @@ impl PyCSimple { && let Some(s) = value.downcast_ref::() { let (holder, ptr) = super::base::str_to_wchar_bytes(s.as_str(), vm); - *zelf.0.buffer.write() = std::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); + *zelf.0.buffer.write() = alloc::borrow::Cow::Owned(ptr.to_ne_bytes().to_vec()); *zelf.0.objects.write() = Some(holder); return Ok(()); } @@ -1278,7 +1278,7 @@ impl PyCSimple { // Update buffer when value changes let buffer_bytes = value_to_bytes_endian(&type_code, &content, swapped, vm); - *zelf.0.buffer.write() = std::borrow::Cow::Owned(buffer_bytes); + *zelf.0.buffer.write() = alloc::borrow::Cow::Owned(buffer_bytes); // For c_char_p (type "z"), c_wchar_p (type "Z"), and py_object (type "O"), // keep the reference in _objects @@ -1336,65 +1336,65 @@ impl PyCSimple { let buffer = self.0.buffer.read(); let bytes: &[u8] = &buffer; - if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u8().as_raw_ptr()) { + if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u8().as_raw_ptr()) { if !bytes.is_empty() { return Some(FfiArgValue::U8(bytes[0])); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i8().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i8().as_raw_ptr()) { if !bytes.is_empty() { return Some(FfiArgValue::I8(bytes[0] as i8)); } - } else if std::ptr::eq(ty.as_raw_ptr(), 
libffi::middle::Type::u16().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u16().as_raw_ptr()) { if bytes.len() >= 2 { return Some(FfiArgValue::U16(u16::from_ne_bytes([bytes[0], bytes[1]]))); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i16().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i16().as_raw_ptr()) { if bytes.len() >= 2 { return Some(FfiArgValue::I16(i16::from_ne_bytes([bytes[0], bytes[1]]))); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u32().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u32().as_raw_ptr()) { if bytes.len() >= 4 { return Some(FfiArgValue::U32(u32::from_ne_bytes([ bytes[0], bytes[1], bytes[2], bytes[3], ]))); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i32().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i32().as_raw_ptr()) { if bytes.len() >= 4 { return Some(FfiArgValue::I32(i32::from_ne_bytes([ bytes[0], bytes[1], bytes[2], bytes[3], ]))); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u64().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::u64().as_raw_ptr()) { if bytes.len() >= 8 { return Some(FfiArgValue::U64(u64::from_ne_bytes([ bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], ]))); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i64().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::i64().as_raw_ptr()) { if bytes.len() >= 8 { return Some(FfiArgValue::I64(i64::from_ne_bytes([ bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], ]))); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f32().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f32().as_raw_ptr()) { if bytes.len() >= 4 { return Some(FfiArgValue::F32(f32::from_ne_bytes([ 
bytes[0], bytes[1], bytes[2], bytes[3], ]))); } - } else if std::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f64().as_raw_ptr()) { + } else if core::ptr::eq(ty.as_raw_ptr(), libffi::middle::Type::f64().as_raw_ptr()) { if bytes.len() >= 8 { return Some(FfiArgValue::F64(f64::from_ne_bytes([ bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], ]))); } - } else if std::ptr::eq( + } else if core::ptr::eq( ty.as_raw_ptr(), libffi::middle::Type::pointer().as_raw_ptr(), - ) && bytes.len() >= std::mem::size_of::() + ) && bytes.len() >= core::mem::size_of::() { let val = - usize::from_ne_bytes(bytes[..std::mem::size_of::()].try_into().unwrap()); + usize::from_ne_bytes(bytes[..core::mem::size_of::()].try_into().unwrap()); return Some(FfiArgValue::Pointer(val)); } None diff --git a/crates/vm/src/stdlib/ctypes/structure.rs b/crates/vm/src/stdlib/ctypes/structure.rs index d5aca392c52..295ce0d87cf 100644 --- a/crates/vm/src/stdlib/ctypes/structure.rs +++ b/crates/vm/src/stdlib/ctypes/structure.rs @@ -6,9 +6,9 @@ use crate::function::PySetterValue; use crate::protocol::{BufferDescriptor, PyBuffer, PyNumberMethods}; use crate::types::{AsBuffer, AsNumber, Constructor, Initializer, SetAttr}; use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; +use alloc::borrow::Cow; +use core::fmt::Debug; use num_traits::ToPrimitive; -use std::borrow::Cow; -use std::fmt::Debug; /// Calculate Structure type size from _fields_ (sum of field sizes) pub(super) fn calculate_struct_size(cls: &Py, vm: &VirtualMachine) -> PyResult { @@ -206,7 +206,7 @@ impl PyCStructType { { ( baseinfo.size, - std::cmp::max(baseinfo.align, forced_alignment), + core::cmp::max(baseinfo.align, forced_alignment), baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASPOINTER), baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASUNION), baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASBITFIELD), @@ -252,7 +252,7 @@ impl PyCStructType { // Calculate effective alignment 
(PyCField_FromDesc) let effective_align = if pack > 0 { - std::cmp::min(pack, field_align) + core::cmp::min(pack, field_align) } else { field_align }; @@ -347,7 +347,7 @@ impl PyCStructType { } // Calculate total_align = max(max_align, forced_alignment) - let total_align = std::cmp::max(max_align, forced_alignment); + let total_align = core::cmp::max(max_align, forced_alignment); // Calculate aligned_size (PyCStructUnionType_update_stginfo) let aligned_size = if total_align > 0 { @@ -501,7 +501,7 @@ impl SetAttr for PyCStructType { pub struct PyCStructure(pub PyCData); impl Debug for PyCStructure { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyCStructure") .field("size", &self.0.size()) .finish() diff --git a/crates/vm/src/stdlib/ctypes/union.rs b/crates/vm/src/stdlib/ctypes/union.rs index 41bc7492a25..0da1ffee3fd 100644 --- a/crates/vm/src/stdlib/ctypes/union.rs +++ b/crates/vm/src/stdlib/ctypes/union.rs @@ -7,7 +7,7 @@ use crate::function::PySetterValue; use crate::protocol::{BufferDescriptor, PyBuffer}; use crate::types::{AsBuffer, Constructor, Initializer, SetAttr}; use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; -use std::borrow::Cow; +use alloc::borrow::Cow; /// Calculate Union type size from _fields_ (max field size) pub(super) fn calculate_union_size(cls: &Py, vm: &VirtualMachine) -> PyResult { @@ -175,7 +175,7 @@ impl PyCUnionType { { ( baseinfo.size, - std::cmp::max(baseinfo.align, forced_alignment), + core::cmp::max(baseinfo.align, forced_alignment), baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASPOINTER), baseinfo.flags.contains(StgInfoFlags::TYPEFLAG_HASBITFIELD), baseinfo.ffi_field_types.clone(), @@ -215,7 +215,7 @@ impl PyCUnionType { // Calculate effective alignment let effective_align = if pack > 0 { - std::cmp::min(pack, field_align) + core::cmp::min(pack, field_align) } else { field_align }; @@ -264,7 
+264,7 @@ impl PyCUnionType { } // Calculate total_align and aligned_size - let total_align = std::cmp::max(max_align, forced_alignment); + let total_align = core::cmp::max(max_align, forced_alignment); let aligned_size = if total_align > 0 { max_size.div_ceil(total_align) * total_align } else { @@ -418,8 +418,8 @@ impl SetAttr for PyCUnionType { #[repr(transparent)] pub struct PyCUnion(pub PyCData); -impl std::fmt::Debug for PyCUnion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyCUnion { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("PyCUnion") .field("size", &self.0.size()) .finish() diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index 89dc8bca925..bd451eff98e 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -53,7 +53,7 @@ impl ToOSErrorBuilder for std::io::Error { let msg = { let ptr = unsafe { libc::strerror(errno) }; if !ptr.is_null() { - unsafe { std::ffi::CStr::from_ptr(ptr) } + unsafe { core::ffi::CStr::from_ptr(ptr) } .to_string_lossy() .into_owned() } else { @@ -183,16 +183,16 @@ mod _io { }, vm::VirtualMachine, }; + use alloc::borrow::Cow; use bstr::ByteSlice; - use crossbeam_utils::atomic::AtomicCell; - use malachite_bigint::BigInt; - use num_traits::ToPrimitive; - use std::{ - borrow::Cow, - io::{self, Cursor, SeekFrom, prelude::*}, + use core::{ ops::Range, sync::atomic::{AtomicBool, Ordering}, }; + use crossbeam_utils::atomic::AtomicCell; + use malachite_bigint::BigInt; + use num_traits::ToPrimitive; + use std::io::{self, Cursor, SeekFrom, prelude::*}; #[allow(clippy::let_and_return)] fn validate_whence(whence: i32) -> bool { @@ -354,7 +354,7 @@ mod _io { // if we don't specify the number of bytes, or it's too big, give the whole rest of the slice let n = bytes.map_or_else( || avail_slice.len(), - |n| std::cmp::min(n, avail_slice.len()), + |n| core::cmp::min(n, avail_slice.len()), ); let b = 
avail_slice[..n].to_vec(); self.cursor.set_position((pos + n) as u64); @@ -1059,7 +1059,7 @@ mod _io { // TODO: loop if write() raises an interrupt vm.call_method(self.raw.as_ref().unwrap(), "write", (mem_obj,))? } else { - let v = std::mem::take(&mut self.buffer); + let v = core::mem::take(&mut self.buffer); let write_buf = VecBuffer::from(v).into_ref(&vm.ctx); let mem_obj = PyMemoryView::from_buffer_range( write_buf.clone().into_pybuffer(true), @@ -1330,7 +1330,7 @@ mod _io { let res = match v { Either::A(v) => { let v = v.unwrap_or(&mut self.buffer); - let read_buf = VecBuffer::from(std::mem::take(v)).into_ref(&vm.ctx); + let read_buf = VecBuffer::from(core::mem::take(v)).into_ref(&vm.ctx); let mem_obj = PyMemoryView::from_buffer_range( read_buf.clone().into_pybuffer(false), buf_range, @@ -1527,7 +1527,7 @@ mod _io { } else if !(readinto1 && written != 0) { let n = self.fill_buffer(vm)?; if let Some(n) = n.filter(|&n| n > 0) { - let n = std::cmp::min(n, remaining); + let n = core::cmp::min(n, remaining); buf.as_contiguous_mut().unwrap()[written..][..n] .copy_from_slice(&self.buffer[self.pos as usize..][..n]); self.pos += n as Offset; @@ -1881,7 +1881,7 @@ mod _io { } let have = data.readahead(); if have > 0 { - let n = std::cmp::min(have as usize, n); + let n = core::cmp::min(have as usize, n); return Ok(data.read_fast(n).unwrap()); } // Flush write buffer before reading @@ -2373,7 +2373,7 @@ mod _io { } } - impl std::ops::Add for Utf8size { + impl core::ops::Add for Utf8size { type Output = Self; #[inline] @@ -2383,7 +2383,7 @@ mod _io { } } - impl std::ops::AddAssign for Utf8size { + impl core::ops::AddAssign for Utf8size { #[inline] fn add_assign(&mut self, rhs: Self) { self.bytes += rhs.bytes; @@ -2391,7 +2391,7 @@ mod _io { } } - impl std::ops::Sub for Utf8size { + impl core::ops::Sub for Utf8size { type Output = Self; #[inline] @@ -2401,7 +2401,7 @@ mod _io { } } - impl std::ops::SubAssign for Utf8size { + impl core::ops::SubAssign for Utf8size { #[inline] 
fn sub_assign(&mut self, rhs: Self) { self.bytes -= rhs.bytes; @@ -2470,7 +2470,7 @@ mod _io { impl PendingWrites { fn push(&mut self, write: PendingWrite) { self.num_bytes += write.as_bytes().len(); - self.data = match std::mem::take(&mut self.data) { + self.data = match core::mem::take(&mut self.data) { PendingWritesData::None => PendingWritesData::One(write), PendingWritesData::One(write1) => PendingWritesData::Many(vec![write1, write]), PendingWritesData::Many(mut v) => { @@ -2480,13 +2480,13 @@ mod _io { } } fn take(&mut self, vm: &VirtualMachine) -> PyBytesRef { - let Self { num_bytes, data } = std::mem::take(self); + let Self { num_bytes, data } = core::mem::take(self); if let PendingWritesData::One(PendingWrite::Bytes(b)) = data { return b; } let writes_iter = match data { PendingWritesData::None => itertools::Either::Left(vec![].into_iter()), - PendingWritesData::One(write) => itertools::Either::Right(std::iter::once(write)), + PendingWritesData::One(write) => itertools::Either::Right(core::iter::once(write)), PendingWritesData::Many(writes) => itertools::Either::Left(writes.into_iter()), }; let mut buf = Vec::with_capacity(num_bytes); @@ -2508,7 +2508,7 @@ mod _io { impl TextIOCookie { const START_POS_OFF: usize = 0; - const DEC_FLAGS_OFF: usize = Self::START_POS_OFF + std::mem::size_of::(); + const DEC_FLAGS_OFF: usize = Self::START_POS_OFF + core::mem::size_of::(); const BYTES_TO_FEED_OFF: usize = Self::DEC_FLAGS_OFF + 4; const CHARS_TO_SKIP_OFF: usize = Self::BYTES_TO_FEED_OFF + 4; const NEED_EOF_OFF: usize = Self::CHARS_TO_SKIP_OFF + 4; @@ -2525,7 +2525,7 @@ mod _io { macro_rules! get_field { ($t:ty, $off:ident) => {{ <$t>::from_ne_bytes( - buf[Self::$off..][..std::mem::size_of::<$t>()] + buf[Self::$off..][..core::mem::size_of::<$t>()] .try_into() .unwrap(), ) @@ -2546,7 +2546,7 @@ mod _io { macro_rules! 
set_field { ($field:expr, $off:ident) => {{ let field = $field; - buf[Self::$off..][..std::mem::size_of_val(&field)] + buf[Self::$off..][..core::mem::size_of_val(&field)] .copy_from_slice(&field.to_ne_bytes()) }}; } @@ -3509,7 +3509,7 @@ mod _io { } else { size_hint }; - let chunk_size = std::cmp::max(self.chunk_size, size_hint); + let chunk_size = core::cmp::max(self.chunk_size, size_hint); let input_chunk = vm.call_method(&self.buffer, method, (chunk_size,))?; let buf = ArgBytesLike::try_from_borrowed_object(vm, &input_chunk).map_err(|_| { @@ -3591,8 +3591,8 @@ mod _io { vm: &VirtualMachine, ) -> PyStrRef { let empty_str = || vm.ctx.empty_str.to_owned(); - let chars_pos = std::mem::take(&mut self.decoded_chars_used).bytes; - let decoded_chars = match std::mem::take(&mut self.decoded_chars) { + let chars_pos = core::mem::take(&mut self.decoded_chars_used).bytes; + let decoded_chars = match core::mem::take(&mut self.decoded_chars) { None => return append.unwrap_or_else(empty_str), Some(s) if s.is_empty() => return append.unwrap_or_else(empty_str), Some(s) => s, @@ -4294,7 +4294,7 @@ mod _io { plus: bool, } - impl std::str::FromStr for Mode { + impl core::str::FromStr for Mode { type Err = ParseModeError; fn from_str(s: &str) -> Result { diff --git a/crates/vm/src/stdlib/itertools.rs b/crates/vm/src/stdlib/itertools.rs index 3fedd17f12b..3aad2f91931 100644 --- a/crates/vm/src/stdlib/itertools.rs +++ b/crates/vm/src/stdlib/itertools.rs @@ -25,8 +25,8 @@ mod decl { use malachite_bigint::BigInt; use num_traits::One; + use alloc::fmt; use num_traits::{Signed, ToPrimitive}; - use std::fmt; fn pickle_deprecation(vm: &VirtualMachine) -> PyResult<()> { warnings::warn( @@ -1320,7 +1320,7 @@ mod decl { for arg in iterables.iter() { pools.push(arg.try_to_value(vm)?); } - let pools = std::iter::repeat_n(pools, repeat) + let pools = core::iter::repeat_n(pools, repeat) .flatten() .collect::>>(); diff --git a/crates/vm/src/stdlib/mod.rs b/crates/vm/src/stdlib/mod.rs index 
9fae516fe04..e46f333a28b 100644 --- a/crates/vm/src/stdlib/mod.rs +++ b/crates/vm/src/stdlib/mod.rs @@ -62,7 +62,8 @@ mod winapi; mod winreg; use crate::{PyRef, VirtualMachine, builtins::PyModule}; -use std::{borrow::Cow, collections::HashMap}; +use alloc::borrow::Cow; +use std::collections::HashMap; pub type StdlibInitFunc = Box PyRef)>; pub type StdlibMap = HashMap, StdlibInitFunc, ahash::RandomState>; diff --git a/crates/vm/src/stdlib/nt.rs b/crates/vm/src/stdlib/nt.rs index e056142658d..a32959808c0 100644 --- a/crates/vm/src/stdlib/nt.rs +++ b/crates/vm/src/stdlib/nt.rs @@ -497,7 +497,7 @@ pub(crate) mod module { wide_path.as_ptr(), FILE_READ_ATTRIBUTES, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - std::ptr::null(), + core::ptr::null(), OPEN_EXISTING, flags, std::ptr::null_mut(), @@ -517,7 +517,7 @@ pub(crate) mod module { wide_path.as_ptr(), FILE_READ_ATTRIBUTES, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - std::ptr::null(), + core::ptr::null(), OPEN_EXISTING, FILE_FLAG_OPEN_REPARSE_POINT, std::ptr::null_mut(), @@ -568,7 +568,7 @@ pub(crate) mod module { wide_path.as_ptr(), FILE_READ_ATTRIBUTES, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - std::ptr::null(), + core::ptr::null(), OPEN_EXISTING, flags, std::ptr::null_mut(), @@ -586,7 +586,7 @@ pub(crate) mod module { wide_path.as_ptr(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_WRITE, - std::ptr::null(), + core::ptr::null(), OPEN_EXISTING, 0, std::ptr::null_mut(), @@ -733,7 +733,7 @@ pub(crate) mod module { volume.as_ptr(), FILE_READ_ATTRIBUTES, FILE_SHARE_READ | FILE_SHARE_WRITE, - std::ptr::null(), + core::ptr::null(), OPEN_EXISTING, FILE_FLAG_BACKUP_SEMANTICS, std::ptr::null_mut(), @@ -862,7 +862,7 @@ pub(crate) mod module { conout.as_ptr(), Foundation::GENERIC_READ | Foundation::GENERIC_WRITE, FileSystem::FILE_SHARE_READ | FileSystem::FILE_SHARE_WRITE, - std::ptr::null(), + core::ptr::null(), FileSystem::OPEN_EXISTING, 0, std::ptr::null_mut(), @@ -933,7 +933,7 @@ 
pub(crate) mod module { let argv_spawn: Vec<*const u16> = argv .iter() .map(|v| v.as_ptr()) - .chain(once(std::ptr::null())) + .chain(once(core::ptr::null())) .collect(); let result = unsafe { suppress_iph!(_wspawnv(mode, path.as_ptr(), argv_spawn.as_ptr())) }; @@ -976,7 +976,7 @@ pub(crate) mod module { let argv_spawn: Vec<*const u16> = argv .iter() .map(|v| v.as_ptr()) - .chain(once(std::ptr::null())) + .chain(once(core::ptr::null())) .collect(); // Build environment strings as "KEY=VALUE\0" wide strings @@ -1004,7 +1004,7 @@ pub(crate) mod module { let envp: Vec<*const u16> = env_strings .iter() .map(|s| s.as_ptr()) - .chain(once(std::ptr::null())) + .chain(once(core::ptr::null())) .collect(); let result = unsafe { @@ -1052,7 +1052,7 @@ pub(crate) mod module { let argv_execv: Vec<*const u16> = argv .iter() .map(|v| v.as_ptr()) - .chain(once(std::ptr::null())) + .chain(once(core::ptr::null())) .collect(); if (unsafe { suppress_iph!(_wexecv(path.as_ptr(), argv_execv.as_ptr())) } == -1) { @@ -1093,7 +1093,7 @@ pub(crate) mod module { let argv_execve: Vec<*const u16> = argv .iter() .map(|v| v.as_ptr()) - .chain(once(std::ptr::null())) + .chain(once(core::ptr::null())) .collect(); // Build environment strings as "KEY=VALUE\0" wide strings @@ -1121,7 +1121,7 @@ pub(crate) mod module { let envp: Vec<*const u16> = env_strings .iter() .map(|s| s.as_ptr()) - .chain(once(std::ptr::null())) + .chain(once(core::ptr::null())) .collect(); if (unsafe { suppress_iph!(_wexecve(path.as_ptr(), argv_execve.as_ptr(), envp.as_ptr())) } @@ -1356,7 +1356,7 @@ pub(crate) mod module { .chain(std::iter::once(0)) // null-terminated .collect(); - let mut end: *const u16 = std::ptr::null(); + let mut end: *const u16 = core::ptr::null(); let hr = unsafe { windows_sys::Win32::UI::Shell::PathCchSkipRoot(backslashed.as_ptr(), &mut end) }; @@ -1667,7 +1667,7 @@ pub(crate) mod module { let res = CreatePipe( read.as_mut_ptr() as *mut _, write.as_mut_ptr() as *mut _, - std::ptr::null(), + 
core::ptr::null(), 0, ); if res == 0 { @@ -1723,7 +1723,7 @@ pub(crate) mod module { let Some(func) = func else { return 0; }; - let nt_query: NtQueryInformationProcessFn = unsafe { std::mem::transmute(func) }; + let nt_query: NtQueryInformationProcessFn = unsafe { core::mem::transmute(func) }; let mut info: PROCESS_BASIC_INFORMATION = unsafe { std::mem::zeroed() }; @@ -1808,7 +1808,7 @@ pub(crate) mod module { wide_path.as_ptr(), 0, // No access needed, just reading reparse data FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - std::ptr::null(), + core::ptr::null(), OPEN_EXISTING, FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT, std::ptr::null_mut(), @@ -1832,7 +1832,7 @@ pub(crate) mod module { DeviceIoControl( handle, FSCTL_GET_REPARSE_POINT, - std::ptr::null(), + core::ptr::null(), 0, buffer.as_mut_ptr() as *mut _, BUFFER_SIZE as u32, diff --git a/crates/vm/src/stdlib/os.rs b/crates/vm/src/stdlib/os.rs index 6849ea365df..f76ebec06ac 100644 --- a/crates/vm/src/stdlib/os.rs +++ b/crates/vm/src/stdlib/os.rs @@ -171,15 +171,10 @@ pub(super) mod _os { utils::ToCString, vm::VirtualMachine, }; + use core::time::Duration; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; - use std::{ - env, fs, - fs::OpenOptions, - io, - path::PathBuf, - time::{Duration, SystemTime}, - }; + use std::{env, fs, fs::OpenOptions, io, path::PathBuf, time::SystemTime}; const OPEN_DIR_FD: bool = cfg!(not(any(windows, target_os = "redox"))); pub(crate) const MKDIR_DIR_FD: bool = cfg!(not(any(windows, target_os = "redox"))); @@ -509,7 +504,7 @@ pub(super) mod _os { 22, format!( "Invalid argument: {}", - std::str::from_utf8(key).unwrap_or("") + core::str::from_utf8(key).unwrap_or("") ), ); @@ -1042,12 +1037,12 @@ pub(super) mod _os { dir_fd: DirFd<'_, { STAT_DIR_FD as usize }>, follow_symlinks: FollowSymlinks, ) -> io::Result> { - let mut stat = std::mem::MaybeUninit::uninit(); + let mut stat = core::mem::MaybeUninit::uninit(); let ret = match file { 
OsPathOrFd::Path(path) => { use rustpython_common::os::ffi::OsStrExt; let path = path.as_ref().as_os_str().as_bytes(); - let path = match std::ffi::CString::new(path) { + let path = match alloc::ffi::CString::new(path) { Ok(x) => x, Err(_) => return Ok(None), }; @@ -1209,7 +1204,7 @@ pub(super) mod _os { use std::os::windows::io::AsRawHandle; use windows_sys::Win32::Storage::FileSystem; let handle = crt_fd::as_handle(fd).map_err(|e| e.into_pyexception(vm))?; - let mut distance_to_move: [i32; 2] = std::mem::transmute(position); + let mut distance_to_move: [i32; 2] = core::mem::transmute(position); let ret = FileSystem::SetFilePointer( handle.as_raw_handle(), distance_to_move[0], @@ -1220,7 +1215,7 @@ pub(super) mod _os { -1 } else { distance_to_move[0] = ret as _; - std::mem::transmute::<[i32; 2], i64>(distance_to_move) + core::mem::transmute::<[i32; 2], i64>(distance_to_move) } }; if res < 0 { @@ -1402,7 +1397,7 @@ pub(super) mod _os { .map_err(|err| OSErrorBuilder::with_filename(&err, path.clone(), vm))?; let ret = unsafe { - FileSystem::SetFileTime(f.as_raw_handle() as _, std::ptr::null(), &acc, &modif) + FileSystem::SetFileTime(f.as_raw_handle() as _, core::ptr::null(), &acc, &modif) }; if ret == 0 { @@ -1523,9 +1518,9 @@ pub(super) mod _os { #[pyfunction] fn copy_file_range(args: CopyFileRangeArgs<'_>, vm: &VirtualMachine) -> PyResult { #[allow(clippy::unnecessary_option_map_or_else)] - let p_offset_src = args.offset_src.as_ref().map_or_else(std::ptr::null, |x| x); + let p_offset_src = args.offset_src.as_ref().map_or_else(core::ptr::null, |x| x); #[allow(clippy::unnecessary_option_map_or_else)] - let p_offset_dst = args.offset_dst.as_ref().map_or_else(std::ptr::null, |x| x); + let p_offset_dst = args.offset_dst.as_ref().map_or_else(core::ptr::null, |x| x); let count: usize = args .count .try_into() @@ -1557,7 +1552,7 @@ pub(super) mod _os { #[pyfunction] fn strerror(e: i32) -> String { - unsafe { std::ffi::CStr::from_ptr(libc::strerror(e)) } + unsafe { 
core::ffi::CStr::from_ptr(libc::strerror(e)) } .to_string_lossy() .into_owned() } @@ -1661,7 +1656,7 @@ pub(super) mod _os { if encoding.is_null() || encoding.read() == '\0' as libc::c_char { "UTF-8".to_owned() } else { - std::ffi::CStr::from_ptr(encoding).to_string_lossy().into_owned() + core::ffi::CStr::from_ptr(encoding).to_string_lossy().into_owned() } }; diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index efbd0cf9049..6414242ada9 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -33,15 +33,15 @@ pub mod module { types::{Constructor, Representable}, utils::ToCString, }; + use alloc::ffi::CString; use bitflags::bitflags; + use core::ffi::CStr; use nix::{ fcntl, unistd::{self, Gid, Pid, Uid}, }; use std::{ - env, - ffi::{CStr, CString}, - fs, io, + env, fs, io, os::fd::{AsFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd}, }; use strum_macros::{EnumIter, EnumString}; @@ -917,7 +917,7 @@ pub mod module { #[pyfunction] fn sched_getparam(pid: libc::pid_t, vm: &VirtualMachine) -> PyResult { let param = unsafe { - let mut param = std::mem::MaybeUninit::uninit(); + let mut param = core::mem::MaybeUninit::uninit(); if -1 == libc::sched_getparam(pid, param.as_mut_ptr()) { return Err(vm.new_last_errno_error()); } @@ -1280,7 +1280,7 @@ pub mod module { } fn try_from_id(vm: &VirtualMachine, obj: PyObjectRef, typ_name: &str) -> PyResult { - use std::cmp::Ordering; + use core::cmp::Ordering; let i = obj .try_to_ref::(vm) .map_err(|_| { @@ -1838,9 +1838,9 @@ pub mod module { #[pyfunction] fn dup2(args: Dup2Args<'_>, vm: &VirtualMachine) -> PyResult { - let mut fd2 = std::mem::ManuallyDrop::new(args.fd2); + let mut fd2 = core::mem::ManuallyDrop::new(args.fd2); nix::unistd::dup2(args.fd, &mut fd2).map_err(|e| e.into_pyexception(vm))?; - let fd2 = std::mem::ManuallyDrop::into_inner(fd2); + let fd2 = core::mem::ManuallyDrop::into_inner(fd2); if !args.inheritable { super::set_inheritable(fd2.as_fd(), false).map_err(|e| 
e.into_pyexception(vm))? } diff --git a/crates/vm/src/stdlib/pwd.rs b/crates/vm/src/stdlib/pwd.rs index e4d7075dbc8..6405ed7be91 100644 --- a/crates/vm/src/stdlib/pwd.rs +++ b/crates/vm/src/stdlib/pwd.rs @@ -37,7 +37,7 @@ mod pwd { impl From for PasswdData { fn from(user: User) -> Self { // this is just a pain... - let cstr_lossy = |s: std::ffi::CString| { + let cstr_lossy = |s: alloc::ffi::CString| { s.into_string() .unwrap_or_else(|e| e.into_cstring().to_string_lossy().into_owned()) }; @@ -105,7 +105,7 @@ mod pwd { let mut list = Vec::new(); unsafe { libc::setpwent() }; - while let Some(ptr) = std::ptr::NonNull::new(unsafe { libc::getpwent() }) { + while let Some(ptr) = core::ptr::NonNull::new(unsafe { libc::getpwent() }) { let user = User::from(unsafe { ptr.as_ref() }); let passwd = PasswdData::from(user).to_pyobject(vm); list.push(passwd); diff --git a/crates/vm/src/stdlib/signal.rs b/crates/vm/src/stdlib/signal.rs index 810ffabefe6..dd0d9a7a96f 100644 --- a/crates/vm/src/stdlib/signal.rs +++ b/crates/vm/src/stdlib/signal.rs @@ -24,7 +24,7 @@ pub(crate) mod _signal { builtins::PyTypeRef, function::{ArgIntoFloat, OptionalArg}, }; - use std::sync::atomic::{self, Ordering}; + use core::sync::atomic::{self, Ordering}; #[cfg(any(unix, windows))] use libc::sighandler_t; @@ -301,7 +301,7 @@ pub(crate) mod _signal { it_value: double_to_timeval(seconds), it_interval: double_to_timeval(interval), }; - let mut old = std::mem::MaybeUninit::::uninit(); + let mut old = core::mem::MaybeUninit::::uninit(); #[cfg(any(target_os = "linux", target_os = "android"))] let ret = unsafe { ffi::setitimer(which, &new, old.as_mut_ptr()) }; #[cfg(not(any(target_os = "linux", target_os = "android")))] @@ -318,7 +318,7 @@ pub(crate) mod _signal { #[cfg(unix)] #[pyfunction] fn getitimer(which: i32, vm: &VirtualMachine) -> PyResult<(f64, f64)> { - let mut old = std::mem::MaybeUninit::::uninit(); + let mut old = core::mem::MaybeUninit::::uninit(); #[cfg(any(target_os = "linux", target_os = 
"android"))] let ret = unsafe { ffi::getitimer(which, old.as_mut_ptr()) }; #[cfg(not(any(target_os = "linux", target_os = "android")))] @@ -489,7 +489,7 @@ pub(crate) mod _signal { if s.is_null() { Ok(None) } else { - let cstr = unsafe { std::ffi::CStr::from_ptr(s) }; + let cstr = unsafe { core::ffi::CStr::from_ptr(s) }; Ok(Some(cstr.to_string_lossy().into_owned())) } } @@ -522,7 +522,7 @@ pub(crate) mod _signal { #[cfg(unix)] { // Use sigfillset to get all valid signals - let mut mask: libc::sigset_t = unsafe { std::mem::zeroed() }; + let mut mask: libc::sigset_t = unsafe { core::mem::zeroed() }; // SAFETY: mask is a valid pointer if unsafe { libc::sigfillset(&mut mask) } != 0 { return Err(vm.new_os_error("sigfillset failed".to_owned())); @@ -580,7 +580,7 @@ pub(crate) mod _signal { use crate::convert::IntoPyException; // Initialize sigset - let mut sigset: libc::sigset_t = unsafe { std::mem::zeroed() }; + let mut sigset: libc::sigset_t = unsafe { core::mem::zeroed() }; // SAFETY: sigset is a valid pointer if unsafe { libc::sigemptyset(&mut sigset) } != 0 { return Err(std::io::Error::last_os_error().into_pyexception(vm)); @@ -611,7 +611,7 @@ pub(crate) mod _signal { } // Call pthread_sigmask - let mut old_mask: libc::sigset_t = unsafe { std::mem::zeroed() }; + let mut old_mask: libc::sigset_t = unsafe { core::mem::zeroed() }; // SAFETY: all pointers are valid let err = unsafe { libc::pthread_sigmask(how, &sigset, &mut old_mask) }; if err != 0 { diff --git a/crates/vm/src/stdlib/string.rs b/crates/vm/src/stdlib/string.rs index 576cae62775..a9911f3d45f 100644 --- a/crates/vm/src/stdlib/string.rs +++ b/crates/vm/src/stdlib/string.rs @@ -16,7 +16,7 @@ mod _string { convert::ToPyException, convert::ToPyObject, }; - use std::mem; + use core::mem; fn create_format_part( literal: Wtf8Buf, diff --git a/crates/vm/src/stdlib/symtable.rs b/crates/vm/src/stdlib/symtable.rs index 8a142857787..51c5c8e47ea 100644 --- a/crates/vm/src/stdlib/symtable.rs +++ 
b/crates/vm/src/stdlib/symtable.rs @@ -7,10 +7,10 @@ mod symtable { builtins::{PyDictRef, PyStrRef}, compiler, }; + use alloc::fmt; use rustpython_codegen::symboltable::{ CompilerScope, Symbol, SymbolFlags, SymbolScope, SymbolTable, }; - use std::fmt; // Consts as defined at // https://github.com/python/cpython/blob/6cb20a219a860eaf687b2d968b41c480c7461909/Include/internal/pycore_symtable.h#L156 @@ -180,7 +180,7 @@ mod symtable { #[pygetset] fn id(&self) -> usize { - self as *const Self as *const std::ffi::c_void as usize + self as *const Self as *const core::ffi::c_void as usize } #[pygetset] diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 0e46ec18a01..df82bd6416c 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -23,11 +23,11 @@ mod sys { version, vm::{Settings, VirtualMachine}, }; + use core::sync::atomic::Ordering; use num_traits::ToPrimitive; use std::{ env::{self, VarError}, io::Read, - sync::atomic::Ordering, }; #[cfg(windows)] @@ -66,7 +66,7 @@ mod sys { #[pyattr(name = "maxsize")] pub(crate) const MAXSIZE: isize = isize::MAX; #[pyattr(name = "maxunicode")] - const MAXUNICODE: u32 = std::char::MAX as u32; + const MAXUNICODE: u32 = core::char::MAX as u32; #[pyattr(name = "platform")] pub(crate) const PLATFORM: &str = { cfg_if::cfg_if! { @@ -446,7 +446,7 @@ mod sys { let sizeof = || -> PyResult { let res = vm.call_special_method(&args.obj, identifier!(vm, __sizeof__), ())?; let res = res.try_index(vm)?.try_to_primitive::(vm)?; - Ok(res + std::mem::size_of::()) + Ok(res + core::mem::size_of::()) }; sizeof() .map(|x| vm.ctx.new_int(x).into()) @@ -1072,7 +1072,7 @@ mod sys { const INFO: Self = { use rustpython_common::hash::*; Self { - width: std::mem::size_of::() * 8, + width: core::mem::size_of::() * 8, modulus: MODULUS, inf: INF, nan: NAN, @@ -1102,7 +1102,7 @@ mod sys { impl IntInfoData { const INFO: Self = Self { bits_per_digit: 30, //? 
- sizeof_digit: std::mem::size_of::(), + sizeof_digit: core::mem::size_of::(), default_max_str_digits: 4300, str_digits_check_threshold: 640, }; @@ -1220,7 +1220,7 @@ pub(crate) fn init_module(vm: &VirtualMachine, module: &Py, builtins: pub struct PyStderr<'vm>(pub &'vm VirtualMachine); impl PyStderr<'_> { - pub fn write_fmt(&self, args: std::fmt::Arguments<'_>) { + pub fn write_fmt(&self, args: core::fmt::Arguments<'_>) { use crate::py_io::Write; let vm = self.0; diff --git a/crates/vm/src/stdlib/thread.rs b/crates/vm/src/stdlib/thread.rs index 36252279397..d22cbb3b0c1 100644 --- a/crates/vm/src/stdlib/thread.rs +++ b/crates/vm/src/stdlib/thread.rs @@ -11,12 +11,14 @@ pub(crate) mod _thread { function::{ArgCallable, Either, FuncArgs, KwArgs, OptionalArg, PySetterValue}, types::{Constructor, GetAttr, Representable, SetAttr}, }; + use alloc::fmt; + use core::{cell::RefCell, time::Duration}; use crossbeam_utils::atomic::AtomicCell; use parking_lot::{ RawMutex, RawThreadId, lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, }; - use std::{cell::RefCell, fmt, thread, time::Duration}; + use std::thread; use thread_local::ThreadLocal; // PYTHREAD_NAME: show current thread name @@ -151,7 +153,7 @@ pub(crate) mod _thread { let new_mut = RawMutex::INIT; unsafe { - let old_mutex: &AtomicCell = std::mem::transmute(&self.mu); + let old_mutex: &AtomicCell = core::mem::transmute(&self.mu); old_mutex.swap(new_mut); } @@ -264,7 +266,7 @@ pub(crate) mod _thread { } fn thread_to_id(t: &thread::Thread) -> u64 { - use std::hash::{Hash, Hasher}; + use core::hash::{Hash, Hasher}; struct U64Hash { v: Option, } diff --git a/crates/vm/src/stdlib/time.rs b/crates/vm/src/stdlib/time.rs index b9b53cdc5c5..97d60ae98a1 100644 --- a/crates/vm/src/stdlib/time.rs +++ b/crates/vm/src/stdlib/time.rs @@ -21,12 +21,12 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { unsafe extern "C" { #[cfg(not(target_os = "freebsd"))] #[link_name = "daylight"] - static c_daylight: 
std::ffi::c_int; + static c_daylight: core::ffi::c_int; // pub static dstbias: std::ffi::c_int; #[link_name = "timezone"] - static c_timezone: std::ffi::c_long; + static c_timezone: core::ffi::c_long; #[link_name = "tzname"] - static c_tzname: [*const std::ffi::c_char; 2]; + static c_tzname: [*const core::ffi::c_char; 2]; #[link_name = "tzset"] fn c_tzset(); } @@ -43,7 +43,7 @@ mod decl { DateTime, Datelike, TimeZone, Timelike, naive::{NaiveDate, NaiveDateTime, NaiveTime}, }; - use std::time::Duration; + use core::time::Duration; #[cfg(target_env = "msvc")] #[cfg(not(target_arch = "wasm32"))] use windows_sys::Win32::System::Time::{GetTimeZoneInformation, TIME_ZONE_INFORMATION}; @@ -104,7 +104,7 @@ mod decl { { // this is basically std::thread::sleep, but that catches interrupts and we don't want to; let ts = nix::sys::time::TimeSpec::from(dur); - let res = unsafe { libc::nanosleep(ts.as_ref(), std::ptr::null_mut()) }; + let res = unsafe { libc::nanosleep(ts.as_ref(), core::ptr::null_mut()) }; let interrupted = res == -1 && nix::Error::last_raw() == libc::EINTR; if interrupted { @@ -200,7 +200,7 @@ mod decl { #[cfg(not(target_env = "msvc"))] #[cfg(not(target_arch = "wasm32"))] #[pyattr] - fn timezone(_vm: &VirtualMachine) -> std::ffi::c_long { + fn timezone(_vm: &VirtualMachine) -> core::ffi::c_long { unsafe { super::c_timezone } } @@ -217,7 +217,7 @@ mod decl { #[cfg(not(target_env = "msvc"))] #[cfg(not(target_arch = "wasm32"))] #[pyattr] - fn daylight(_vm: &VirtualMachine) -> std::ffi::c_int { + fn daylight(_vm: &VirtualMachine) -> core::ffi::c_int { unsafe { super::c_daylight } } @@ -236,8 +236,8 @@ mod decl { fn tzname(vm: &VirtualMachine) -> crate::builtins::PyTupleRef { use crate::builtins::tuple::IntoPyTuple; - unsafe fn to_str(s: *const std::ffi::c_char) -> String { - unsafe { std::ffi::CStr::from_ptr(s) } + unsafe fn to_str(s: *const core::ffi::c_char) -> String { + unsafe { core::ffi::CStr::from_ptr(s) } .to_string_lossy() .into_owned() } @@ -357,7 +357,7 
@@ mod decl { t: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - use std::fmt::Write; + use core::fmt::Write; let instant = t.naive_or_local(vm)?; @@ -500,8 +500,8 @@ mod decl { pub tm_zone: PyObjectRef, } - impl std::fmt::Debug for StructTimeData { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl core::fmt::Debug for StructTimeData { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "struct_time()") } } @@ -590,8 +590,8 @@ mod platform { builtins::{PyNamespace, PyStrRef}, convert::IntoPyException, }; + use core::time::Duration; use nix::{sys::time::TimeSpec, time::ClockId}; - use std::time::Duration; #[cfg(target_os = "solaris")] #[pyattr] @@ -818,7 +818,7 @@ mod platform { fn u64_from_filetime(time: FILETIME) -> u64 { let large: [u32; 2] = [time.dwLowDateTime, time.dwHighDateTime]; - unsafe { std::mem::transmute(large) } + unsafe { core::mem::transmute(large) } } fn win_perf_counter_frequency(vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/stdlib/winapi.rs b/crates/vm/src/stdlib/winapi.rs index 5cfb62fad6f..f01843f6f62 100644 --- a/crates/vm/src/stdlib/winapi.rs +++ b/crates/vm/src/stdlib/winapi.rs @@ -107,7 +107,7 @@ mod _winapi { WindowsSysResult(windows_sys::Win32::System::Pipes::CreatePipe( read.as_mut_ptr(), write.as_mut_ptr(), - std::ptr::null(), + core::ptr::null(), size, )) .to_pyresult(vm)?; @@ -278,8 +278,8 @@ mod _winapi { WindowsSysResult(windows_sys::Win32::System::Threading::CreateProcessW( app_name, command_line, - std::ptr::null(), - std::ptr::null(), + core::ptr::null(), + core::ptr::null(), args.inherit_handles, args.creation_flags | windows_sys::Win32::System::Threading::EXTENDED_STARTUPINFO_PRESENT @@ -454,7 +454,7 @@ mod _winapi { handlelist.as_mut_ptr() as _, (handlelist.len() * std::mem::size_of::()) as _, std::ptr::null_mut(), - std::ptr::null(), + core::ptr::null(), ) }) .into_pyresult(vm)?; @@ -873,7 +873,7 @@ mod _winapi { } let buf = buffer.borrow_buf(); - 
let len = std::cmp::min(buf.len(), u32::MAX as usize) as u32; + let len = core::cmp::min(buf.len(), u32::MAX as usize) as u32; let mut written: u32 = 0; let ret = unsafe { @@ -947,7 +947,7 @@ mod _winapi { let mut batches: Vec> = Vec::new(); let mut i = 0; while i < nhandles { - let end = std::cmp::min(i + batch_size, nhandles); + let end = core::cmp::min(i + batch_size, nhandles); batches.push(handles[i..end].to_vec()); i = end; } diff --git a/crates/vm/src/stdlib/winreg.rs b/crates/vm/src/stdlib/winreg.rs index b5e568fce6d..f3d8ca10768 100644 --- a/crates/vm/src/stdlib/winreg.rs +++ b/crates/vm/src/stdlib/winreg.rs @@ -367,10 +367,10 @@ mod winreg { key, wide_sub_key.as_ptr(), args.reserved, - std::ptr::null(), + core::ptr::null(), Registry::REG_OPTION_NON_VOLATILE, args.access, - std::ptr::null(), + core::ptr::null(), &mut res, std::ptr::null_mut(), ) @@ -404,7 +404,9 @@ mod winreg { #[pyfunction] fn DeleteValue(key: PyRef, value: Option, vm: &VirtualMachine) -> PyResult<()> { let wide_value = value.map(|v| v.to_wide_with_nul()); - let value_ptr = wide_value.as_ref().map_or(std::ptr::null(), |v| v.as_ptr()); + let value_ptr = wide_value + .as_ref() + .map_or(core::ptr::null(), |v| v.as_ptr()); let res = unsafe { Registry::RegDeleteValueW(key.hkey.load(), value_ptr) }; if res == 0 { Ok(()) @@ -713,7 +715,7 @@ mod winreg { let res = unsafe { Registry::RegQueryValueExW( target_key, - std::ptr::null(), // NULL value name for default value + core::ptr::null(), // NULL value name for default value std::ptr::null_mut(), &mut reg_type, buffer.as_mut_ptr(), @@ -871,10 +873,10 @@ mod winreg { hkey, wide_sub_key.as_ptr(), 0, - std::ptr::null(), + core::ptr::null(), 0, Registry::KEY_SET_VALUE, - std::ptr::null(), + core::ptr::null(), &mut out_key, std::ptr::null_mut(), ) @@ -893,7 +895,7 @@ mod winreg { let res = unsafe { Registry::RegSetValueExW( target_key, - std::ptr::null(), // value name is NULL + core::ptr::null(), // value name is NULL 0, typ, wide_value.as_ptr() as 
*const u8, @@ -1104,7 +1106,7 @@ mod winreg { } Ok(None) => { let len = 0; - let ptr = std::ptr::null(); + let ptr = core::ptr::null(); let wide_value_name = value_name.to_wide_with_nul(); let res = unsafe { Registry::RegSetValueExW( diff --git a/crates/vm/src/suggestion.rs b/crates/vm/src/suggestion.rs index 866deb668eb..55326d1d3f0 100644 --- a/crates/vm/src/suggestion.rs +++ b/crates/vm/src/suggestion.rs @@ -7,8 +7,8 @@ use crate::{ exceptions::types::PyBaseException, sliceable::SliceableSequenceOp, }; +use core::iter::ExactSizeIterator; use rustpython_common::str::levenshtein::{MOVE_COST, levenshtein_distance}; -use std::iter::ExactSizeIterator; const MAX_CANDIDATE_ITEMS: usize = 750; diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 658e21cba8a..9d7d09f2d18 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -17,9 +17,9 @@ use crate::{ types::slot_defs::{SlotAccessor, find_slot_defs_by_name}, vm::Context, }; +use core::{any::Any, any::TypeId, borrow::Borrow, cmp::Ordering, ops::Deref}; use crossbeam_utils::atomic::AtomicCell; use num_traits::{Signed, ToPrimitive}; -use std::{any::Any, any::TypeId, borrow::Borrow, cmp::Ordering, ops::Deref}; /// Type-erased storage for extension module data attached to heap types. 
pub struct TypeDataSlot { @@ -71,7 +71,7 @@ impl<'a, T: Any + 'static> TypeDataRef<'a, T> { } } -impl std::ops::Deref for TypeDataRef<'_, T> { +impl core::ops::Deref for TypeDataRef<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -96,7 +96,7 @@ impl<'a, T: Any + 'static> TypeDataRefMut<'a, T> { } } -impl std::ops::Deref for TypeDataRefMut<'_, T> { +impl core::ops::Deref for TypeDataRefMut<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -104,7 +104,7 @@ impl std::ops::Deref for TypeDataRefMut<'_, T> { } } -impl std::ops::DerefMut for TypeDataRefMut<'_, T> { +impl core::ops::DerefMut for TypeDataRefMut<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.guard } @@ -203,8 +203,8 @@ impl PyTypeSlots { } } -impl std::fmt::Debug for PyTypeSlots { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Debug for PyTypeSlots { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str("PyTypeSlots") } } @@ -1310,7 +1310,7 @@ impl PyType { /// - Special class type handling (e.g., `PyType` and its metaclasses) /// - Post-creation mutations that require `PyRef` #[pyclass] -pub trait Constructor: PyPayload + std::fmt::Debug { +pub trait Constructor: PyPayload + core::fmt::Debug { type Args: FromArgs; /// The type slot for `__new__`. 
Override this only when you need special @@ -1328,7 +1328,7 @@ pub trait Constructor: PyPayload + std::fmt::Debug { fn py_new(cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult; } -pub trait DefaultConstructor: PyPayload + Default + std::fmt::Debug { +pub trait DefaultConstructor: PyPayload + Default + core::fmt::Debug { fn construct_and_init(args: Self::Args, vm: &VirtualMachine) -> PyResult> where Self: Initializer, @@ -1862,7 +1862,7 @@ pub trait Iterable: PyPayload { fn extend_slots(_slots: &mut PyTypeSlots) {} } -// `Iterator` fits better, but to avoid confusion with rust std::iter::Iterator +// `Iterator` fits better, but to avoid confusion with rust core::iter::Iterator #[pyclass(with(Iterable))] pub trait IterNext: PyPayload + Iterable { #[pyslot] diff --git a/crates/vm/src/utils.rs b/crates/vm/src/utils.rs index af34405c7be..db232e81949 100644 --- a/crates/vm/src/utils.rs +++ b/crates/vm/src/utils.rs @@ -14,15 +14,15 @@ pub fn hash_iter<'a, I: IntoIterator>( vm.state.hash_secret.hash_iter(iter, |obj| obj.hash(vm)) } -impl ToPyObject for std::convert::Infallible { +impl ToPyObject for core::convert::Infallible { fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef { match self {} } } pub trait ToCString: AsRef { - fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { - std::ffi::CString::new(self.as_ref().as_bytes()).map_err(|err| err.to_pyexception(vm)) + fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { + alloc::ffi::CString::new(self.as_ref().as_bytes()).map_err(|err| err.to_pyexception(vm)) } fn ensure_no_nul(&self, vm: &VirtualMachine) -> PyResult<()> { if self.as_ref().as_bytes().contains(&b'\0') { @@ -45,7 +45,7 @@ pub(crate) fn collection_repr<'a, I>( vm: &VirtualMachine, ) -> PyResult where - I: std::iter::Iterator, + I: core::iter::Iterator, { let mut repr = String::new(); if let Some(name) = class_name { diff --git a/crates/vm/src/version.rs b/crates/vm/src/version.rs index 0a598842a56..cc118ee5b0e 100644 --- 
a/crates/vm/src/version.rs +++ b/crates/vm/src/version.rs @@ -1,7 +1,8 @@ //! Several function to retrieve version information. use chrono::{Local, prelude::DateTime}; -use std::time::{Duration, UNIX_EPOCH}; +use core::time::Duration; +use std::time::UNIX_EPOCH; // = 3.13.0alpha pub const MAJOR: usize = 3; diff --git a/crates/vm/src/vm/context.rs b/crates/vm/src/vm/context.rs index 486c1861fb1..b12352f6eee 100644 --- a/crates/vm/src/vm/context.rs +++ b/crates/vm/src/vm/context.rs @@ -264,7 +264,7 @@ declare_const_name! { // Basic objects: impl Context { - pub const INT_CACHE_POOL_RANGE: std::ops::RangeInclusive = (-5)..=256; + pub const INT_CACHE_POOL_RANGE: core::ops::RangeInclusive = (-5)..=256; const INT_CACHE_POOL_MIN: i32 = *Self::INT_CACHE_POOL_RANGE.start(); pub fn genesis() -> &'static PyRc { @@ -374,14 +374,14 @@ impl Context { #[inline] pub fn empty_tuple_typed(&self) -> &Py> { let py: &Py = &self.empty_tuple; - unsafe { std::mem::transmute(py) } + unsafe { core::mem::transmute(py) } } // universal pyref constructor pub fn new_pyref(&self, value: T) -> PyRef

where T: Into

, - P: PyPayload + std::fmt::Debug, + P: PyPayload + core::fmt::Debug, { value.into().into_ref(self) } diff --git a/crates/vm/src/vm/interpreter.rs b/crates/vm/src/vm/interpreter.rs index 503feb3dc7f..8d37ad6c840 100644 --- a/crates/vm/src/vm/interpreter.rs +++ b/crates/vm/src/vm/interpreter.rs @@ -1,6 +1,6 @@ use super::{Context, PyConfig, VirtualMachine, setting::Settings, thread}; use crate::{PyResult, getpath, stdlib::atexit, vm::PyBaseExceptionRef}; -use std::sync::atomic::Ordering; +use core::sync::atomic::Ordering; /// The general interface for the VM /// diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index 34092454059..7c527b3e0da 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -33,6 +33,11 @@ use crate::{ signal, stdlib, warn::WarningsState, }; +use alloc::borrow::Cow; +use core::{ + cell::{Cell, Ref, RefCell}, + sync::atomic::AtomicBool, +}; use crossbeam_utils::atomic::AtomicCell; #[cfg(unix)] use nix::{ @@ -40,11 +45,8 @@ use nix::{ unistd::getpid, }; use std::{ - borrow::Cow, - cell::{Cell, Ref, RefCell}, collections::{HashMap, HashSet}, ffi::{OsStr, OsString}, - sync::atomic::AtomicBool, }; pub use context::Context; @@ -789,14 +791,14 @@ impl VirtualMachine { pub(crate) fn push_exception(&self, exc: Option) { let mut excs = self.exceptions.borrow_mut(); - let prev = std::mem::take(&mut *excs); + let prev = core::mem::take(&mut *excs); excs.prev = Some(Box::new(prev)); excs.exc = exc } pub(crate) fn pop_exception(&self) -> Option { let mut excs = self.exceptions.borrow_mut(); - let cur = std::mem::take(&mut *excs); + let cur = core::mem::take(&mut *excs); *excs = *cur.prev.expect("pop_exception() without nested exc stack"); cur.exc } @@ -811,7 +813,7 @@ impl VirtualMachine { pub(crate) fn set_exception(&self, exc: Option) { // don't be holding the RefCell guard while __del__ is called - let prev = std::mem::replace(&mut self.exceptions.borrow_mut().exc, exc); + let prev = core::mem::replace(&mut 
self.exceptions.borrow_mut().exc, exc); drop(prev); } @@ -984,7 +986,7 @@ impl AsRef for VirtualMachine { } fn core_frozen_inits() -> impl Iterator { - let iter = std::iter::empty(); + let iter = core::iter::empty(); macro_rules! ext_modules { ($iter:ident, $($t:tt)*) => { let $iter = $iter.chain(py_freeze!($($t)*)); diff --git a/crates/vm/src/vm/thread.rs b/crates/vm/src/vm/thread.rs index 2e687d99820..7e8f0f87e56 100644 --- a/crates/vm/src/vm/thread.rs +++ b/crates/vm/src/vm/thread.rs @@ -1,10 +1,10 @@ use crate::{AsObject, PyObject, PyObjectRef, VirtualMachine}; -use itertools::Itertools; -use std::{ +use core::{ cell::{Cell, RefCell}, ptr::NonNull, - thread_local, }; +use itertools::Itertools; +use std::thread_local; thread_local! { pub(super) static VM_STACK: RefCell>> = Vec::with_capacity(1).into(); diff --git a/crates/vm/src/vm/vm_new.rs b/crates/vm/src/vm/vm_new.rs index 6d0e983c844..ba09c8ecf69 100644 --- a/crates/vm/src/vm/vm_new.rs +++ b/crates/vm/src/vm/vm_new.rs @@ -95,7 +95,7 @@ impl VirtualMachine { pub fn new_exception(&self, exc_type: PyTypeRef, args: Vec) -> PyBaseExceptionRef { debug_assert_eq!( exc_type.slots.basicsize, - std::mem::size_of::(), + core::mem::size_of::(), "vm.new_exception() is only for exception types without additional payload. 
The given type '{}' is not allowed.", exc_type.class().name() ); @@ -118,7 +118,7 @@ impl VirtualMachine { errno: Option, msg: impl ToPyObject, ) -> PyRef { - debug_assert_eq!(exc_type.slots.basicsize, std::mem::size_of::()); + debug_assert_eq!(exc_type.slots.basicsize, core::mem::size_of::()); OSErrorBuilder::with_subtype(exc_type, errno, msg, self).build(self) } diff --git a/crates/vm/src/vm/vm_ops.rs b/crates/vm/src/vm/vm_ops.rs index 635fa10e630..1d466984377 100644 --- a/crates/vm/src/vm/vm_ops.rs +++ b/crates/vm/src/vm/vm_ops.rs @@ -302,8 +302,8 @@ impl VirtualMachine { } if let Some(slot_c) = class_c.slots.as_number.left_ternary_op(op_slot) - && slot_a.is_some_and(|slot_a| !std::ptr::fn_addr_eq(slot_a, slot_c)) - && slot_b.is_some_and(|slot_b| !std::ptr::fn_addr_eq(slot_b, slot_c)) + && slot_a.is_some_and(|slot_a| !core::ptr::fn_addr_eq(slot_a, slot_c)) + && slot_b.is_some_and(|slot_b| !core::ptr::fn_addr_eq(slot_b, slot_c)) { let ret = slot_c(a, b, c, self)?; if !ret.is(&self.ctx.not_implemented) { diff --git a/crates/vm/src/warn.rs b/crates/vm/src/warn.rs index 3dbd43ab537..09d48078e56 100644 --- a/crates/vm/src/warn.rs +++ b/crates/vm/src/warn.rs @@ -60,7 +60,7 @@ fn get_warnings_attr( && !vm .state .finalizing - .load(std::sync::atomic::Ordering::SeqCst) + .load(core::sync::atomic::Ordering::SeqCst) { match vm.import("warnings", 0) { Ok(module) => module, diff --git a/crates/vm/src/windows.rs b/crates/vm/src/windows.rs index ccf940811b8..ff2b612c06d 100644 --- a/crates/vm/src/windows.rs +++ b/crates/vm/src/windows.rs @@ -301,7 +301,7 @@ fn win32_xstat_slow_impl(path: &OsStr, traverse: bool) -> std::io::Result std::io::Result std::io::Result PyResult<()> { } #[cfg(feature = "flame-it")] -fn write_profile(settings: &Settings) -> Result<(), Box> { +fn write_profile(settings: &Settings) -> Result<(), Box> { use std::{fs, io}; enum ProfileFormat {