diff --git a/CHANGELOG.md b/CHANGELOG.md index e152cb1..c30c160 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ and `AsyncEngine`. - Detects attempts to use `event.listen()` with `AsyncConnection` or `AsyncEngine` and raises a more helpful error message ([#1][]). + +### Changed +- Trio support requires trio 0.9+ ### Fixed - `ThreadWorker.quit()` will raise `AlreadyQuit` instead of blocking. diff --git a/setup.py b/setup.py index f7fb058..27862c6 100644 --- a/setup.py +++ b/setup.py @@ -31,11 +31,11 @@ ] extras_require['test'] = extras_require['test-noextras'] + [ - 'pytest-trio >= 0.4.1', + 'pytest-trio >= 0.5.1', ] extras_require['trio'] = [ - 'trio >= 0.3', + 'trio >= 0.9', ] diff --git a/sqlalchemy_aio/trio.py b/sqlalchemy_aio/trio.py index 449c43f..5ade876 100644 --- a/sqlalchemy_aio/trio.py +++ b/sqlalchemy_aio/trio.py @@ -1,8 +1,9 @@ -import trio import threading +from contextlib import suppress from functools import partial import outcome +import trio from trio import Cancelled, RunFinishedError from .base import AsyncEngine, ThreadWorker @@ -15,14 +16,22 @@ class TrioThreadWorker(ThreadWorker): def __init__(self, *, branch_from=None): if branch_from is None: self._portal = trio.BlockingTrioPortal() - self._request_queue = trio.Queue(1) - self._response_queue = trio.Queue(1) + send_to_thread, receive_from_trio = trio.open_memory_channel(1) + send_to_trio, receive_from_thread = trio.open_memory_channel(1) + + self._send_to_thread = send_to_thread + self._send_to_trio = send_to_trio + self._receive_from_trio = receive_from_trio + self._receive_from_thread = receive_from_thread + self._thread = threading.Thread(target=self.thread_fn, daemon=True) self._thread.start() else: self._portal = branch_from._portal - self._request_queue = branch_from._request_queue - self._response_queue = branch_from._response_queue + self._send_to_thread = branch_from._send_to_thread + self._send_to_trio = branch_from._send_to_trio + self._receive_from_trio = branch_from._receive_from_trio + self._receive_from_thread = branch_from._receive_from_thread self._thread = branch_from._thread self._branched = branch_from is not None @@ -31,19 +40,17 @@ def __init__(self, *, branch_from=None): def thread_fn(self): while True: try: - request = self._portal.run(self._request_queue.get) - except Cancelled: - continue - except RunFinishedError: + request = self._portal.run(self._receive_from_trio.receive) + except (Cancelled, RunFinishedError): break - - if request is not _STOP: - response = outcome.capture(request) - self._portal.run(self._response_queue.put, response) - else: - self._portal.run(self._response_queue.put, None) + except trio.EndOfChannel: + with suppress(Cancelled, RunFinishedError): + self._portal.run(self._send_to_trio.aclose) break + response = outcome.capture(request) + self._portal.run(self._send_to_trio.send, response) + async def run(self, func, args=(), kwargs=None): if self._has_quit: raise AlreadyQuit @@ -53,8 +60,8 @@ async def run(self, func, args=(), kwargs=None): elif args: func = partial(func, *args) - await self._request_queue.put(func) - resp = await self._response_queue.get() + await self._send_to_thread.send(func) + resp = await self._receive_from_thread.receive() return resp.unwrap() async def quit(self): @@ -66,8 +73,7 @@ async def quit(self): if self._branched: return - await self._request_queue.put(_STOP) - await self._response_queue.get() + await self._send_to_thread.aclose() class TrioEngine(AsyncEngine):