Skip to content

Commit

Permalink
Allow the Manager class to be initialized with peewee.Proxy, see #28
Browse files Browse the repository at this point in the history
  • Loading branch information
rudyryk committed May 31, 2016
1 parent 700b971 commit 70269ab
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
24 changes: 22 additions & 2 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def __init__(self, database=None, *, loop=None):

self.loop = loop or asyncio.get_event_loop()
self.database = database or self.database
self.database.loop = self.loop
attach_callback = getattr(self.database, 'attach_callback', None)
if attach_callback:
attach_callback(lambda db: db.set_event_loop(self.loop))
else:
self.database.set_event_loop(self.loop)

@property
def is_connected(self):
Expand Down Expand Up @@ -819,6 +823,22 @@ class AsyncDatabase:
_async_wait = None # connection waiter
_task_data = None # task context data

def set_event_loop(self, loop):
"""Set event loop for the database. Usually, you don't need to
call this directly. It's called from `Manager.connect()` or
`.connect_async()` methods.
"""
# These checks are not very pythonic, but I believe it's OK to be
# a little paranoid about mismatching of asyncio event loops,
# because such errors won't show clear traceback and could be
# tricky to debug.
loop = loop or asyncio.get_event_loop()
if not self.loop:
self.loop = loop
elif self.loop != loop:
raise RuntimeError("Error, the event loop is already set before. "
"Make sure you're using the same event loop!")

@asyncio.coroutine
def connect_async(self, loop=None, timeout=None):
"""Set up async connection on specified event loop or
Expand All @@ -833,7 +853,7 @@ def connect_async(self, loop=None, timeout=None):
elif self._async_wait:
yield from self._async_wait
else:
self.loop = loop or asyncio.get_event_loop()
self.set_event_loop(loop)
self._async_wait = asyncio.Future(loop=self.loop)

conn = self._async_conn_cls(
Expand Down
23 changes: 23 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,29 @@ def test_deferred_init(self):
TestModel.create_table(True)
TestModel.drop_table(True)

def test_proxy_database(self):
loop = asyncio.new_event_loop()
database = peewee.Proxy()
TestModel._meta.database = database
objects = peewee_async.Manager(database, loop=loop)

@asyncio.coroutine
def test(objects):
text = "Test %s" % uuid.uuid4()
yield from objects.create(TestModel, text=text)

config = dict(defaults)
for k in list(config.keys()):
config[k].update(overrides.get(k, {}))
database.initialize(db_classes[k](**config[k]))

TestModel.create_table(True)
loop.run_until_complete(test(objects))
loop.run_until_complete(objects.close())
TestModel.drop_table(True)

loop.close()


class OlderTestCase(unittest.TestCase):
# only = ['postgres', 'postgres-ext', 'postgres-pool', 'postgres-pool-ext']
Expand Down

0 comments on commit 70269ab

Please sign in to comment.