Skip to content

Commit

Permalink
trio 0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
RazerM committed Oct 26, 2018
1 parent ddec2d1 commit 1798b83
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -16,6 +16,9 @@
and `AsyncEngine`.
- Detects attempts to use `event.listen()` with `AsyncConnection` or
`AsyncEngine` and raises a more helpful error message ([#1][]).

### Changed
- Trio support requires trio 0.9+

### Fixed
- `ThreadWorker.quit()` will raise `AlreadyQuit` instead of blocking.
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Expand Up @@ -31,11 +31,11 @@
]

extras_require['test'] = extras_require['test-noextras'] + [
'pytest-trio >= 0.4.1',
'pytest-trio >= 0.5.1',
]

extras_require['trio'] = [
'trio >= 0.3',
'trio >= 0.9',
]


Expand Down
44 changes: 25 additions & 19 deletions sqlalchemy_aio/trio.py
@@ -1,8 +1,9 @@
import trio
import threading
from contextlib import suppress
from functools import partial

import outcome
import trio
from trio import Cancelled, RunFinishedError

from .base import AsyncEngine, ThreadWorker
Expand All @@ -15,14 +16,22 @@ class TrioThreadWorker(ThreadWorker):
def __init__(self, *, branch_from=None):
if branch_from is None:
self._portal = trio.BlockingTrioPortal()
self._request_queue = trio.Queue(1)
self._response_queue = trio.Queue(1)
send_to_thread, receive_from_trio = trio.open_memory_channel(1)
send_to_trio, receive_from_thread = trio.open_memory_channel(1)

self._send_to_thread = send_to_thread
self._send_to_trio = send_to_trio
self._receive_from_trio = receive_from_trio
self._receive_from_thread = receive_from_thread

self._thread = threading.Thread(target=self.thread_fn, daemon=True)
self._thread.start()
else:
self._portal = branch_from._portal
self._request_queue = branch_from._request_queue
self._response_queue = branch_from._response_queue
self._send_to_thread = branch_from._send_to_thread
self._send_to_trio = branch_from._send_to_trio
self._receive_from_trio = branch_from._receive_from_trio
self._receive_from_thread = branch_from._receive_from_thread
self._thread = branch_from._thread

self._branched = branch_from is not None
Expand All @@ -31,19 +40,17 @@ def __init__(self, *, branch_from=None):
def thread_fn(self):
while True:
try:
request = self._portal.run(self._request_queue.get)
except Cancelled:
continue
except RunFinishedError:
request = self._portal.run(self._receive_from_trio.receive)
except (Cancelled, RunFinishedError):
break

if request is not _STOP:
response = outcome.capture(request)
self._portal.run(self._response_queue.put, response)
else:
self._portal.run(self._response_queue.put, None)
except trio.EndOfChannel:
with suppress(Cancelled, RunFinishedError):
self._portal.run(self._send_to_trio.aclose)
break

response = outcome.capture(request)
self._portal.run(self._send_to_trio.send, response)

async def run(self, func, args=(), kwargs=None):
if self._has_quit:
raise AlreadyQuit
Expand All @@ -53,8 +60,8 @@ async def run(self, func, args=(), kwargs=None):
elif args:
func = partial(func, *args)

await self._request_queue.put(func)
resp = await self._response_queue.get()
await self._send_to_thread.send(func)
resp = await self._receive_from_thread.receive()
return resp.unwrap()

async def quit(self):
Expand All @@ -66,8 +73,7 @@ async def quit(self):
if self._branched:
return

await self._request_queue.put(_STOP)
await self._response_queue.get()
await self._send_to_thread.aclose()


class TrioEngine(AsyncEngine):
Expand Down

0 comments on commit 1798b83

Please sign in to comment.