Skip to content

Commit

Permalink
Fix #30: allow_sync() context manager didn't work with peewee.Proxy; …
Browse files Browse the repository at this point in the history
…also database's .allow_sync converted to context manager; database's .allow_sync setter marked as deprecated
  • Loading branch information
rudyryk committed Jun 1, 2016
1 parent 70269ab commit effe078
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 99 deletions.
143 changes: 90 additions & 53 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,23 @@ def __init__(self, database=None, *, loop=None):
("Error, database must be provided via "
"argument or class member.")

self.loop = loop or asyncio.get_event_loop()
self.database = database or self.database
self._loop = loop

attach_callback = getattr(self.database, 'attach_callback', None)
if attach_callback:
attach_callback(lambda db: db.set_event_loop(self.loop))
attach_callback(lambda db: setattr(db, '_loop', loop))
else:
self.database.set_event_loop(self.loop)
self.database._loop = loop

@property
def loop(self):
"""Get the event loop.
If no event loop is provided explicitly on creating
the instance, just return the current event loop.
"""
return self._loop or asyncio.get_event_loop()

@property
def is_connected(self):
Expand All @@ -123,7 +133,7 @@ def is_connected(self):

@asyncio.coroutine
def get(self, source, *args, **kwargs):
"""Get model instance.
"""Get the model instance.
:param source: model or base query for lookup
Expand Down Expand Up @@ -159,7 +169,7 @@ async def my_async_func():

@asyncio.coroutine
def create(self, model, **data):
"""Create new object saved to database.
"""Create a new object saved to database.
"""
inst = model(**data)
query = model.insert(**dict(inst._data))
Expand All @@ -174,7 +184,7 @@ def create(self, model, **data):

@asyncio.coroutine
def get_or_create(self, model, defaults=None, **kwargs):
"""Try to get object or create it with specified defaults.
"""Try to get an object or create it with the specified defaults.
Return 2-tuple containing the model instance and a boolean
indicating whether the instance was created.
Expand All @@ -189,8 +199,8 @@ def get_or_create(self, model, defaults=None, **kwargs):

@asyncio.coroutine
def update(self, obj, only=None):
"""Update object in database. Optionally, update only specified
fields. For creating new object use :meth:`.create()`
"""Update the object in the database. Optionally, update only
the specified fields. For creating a new object use :meth:`.create()`
:param only: (optional) the list/tuple of fields or
field names to update
Expand Down Expand Up @@ -322,24 +332,16 @@ def savepoint(self, sid=None):
"""
return savepoint(self.database, sid=sid)

@contextlib.contextmanager
def allow_sync(self):
"""Allow sync queries within context. Close sync
connection on exit if connected.
"""Allow sync queries within context. Close the sync
database connection on exit if connected.
Example::
with objects.allow_sync():
PageBlock.create_table(True)
"""
old_allow = self.database.allow_sync
self.database.allow_sync = True
yield
try:
self.database.close()
except self.database.Error:
pass # already closed
self.database.allow_sync = old_allow
return self.database.allow_sync()

def _swap_database(self, query):
"""Swap database for query if swappable. Return **new query**
Expand Down Expand Up @@ -817,27 +819,30 @@ def _get_result_wrapper(self, query):
############

class AsyncDatabase:
allow_sync = True # whether sync queries allowed
loop = None # asyncio event loop
_loop = None # asyncio event loop
_allow_sync = True # whether sync queries are allowed
_async_conn = None # async connection
_async_wait = None # connection waiter
_task_data = None # task context data
_task_data = None # asyncio per-task data

def __setattr__(self, name, value):
if name == 'allow_sync':
warnings.warn(
"`.allow_sync` setter is deprecated, use either the "
"`.allow_sync()` context manager or `.set_allow_sync()` "
"method.", DeprecationWarning)
self._allow_sync = value
else:
super().__setattr__(name, value)

@property
def loop(self):
"""Get the event loop.
def set_event_loop(self, loop):
"""Set event loop for the database. Usually, you don't need to
call this directly. It's called from `Manager.connect()` or
`.connect_async()` methods.
If no event loop is provided explicitly on creating
the instance, just return the current event loop.
"""
# These checks are not very pythonic, but I believe it's OK to be
# a little paranoid about mismatching of asyncio event loops,
# because such errors won't show clear traceback and could be
# tricky to debug.
loop = loop or asyncio.get_event_loop()
if not self.loop:
self.loop = loop
elif self.loop != loop:
raise RuntimeError("Error, the event loop is already set before. "
"Make sure you're using the same event loop!")
return self._loop or asyncio.get_event_loop()

@asyncio.coroutine
def connect_async(self, loop=None, timeout=None):
Expand All @@ -853,12 +858,12 @@ def connect_async(self, loop=None, timeout=None):
elif self._async_wait:
yield from self._async_wait
else:
self.set_event_loop(loop)
self._async_wait = asyncio.Future(loop=self.loop)
self._loop = loop
self._async_wait = asyncio.Future(loop=self._loop)

conn = self._async_conn_cls(
database=self.database,
loop=self.loop,
loop=self._loop,
timeout=timeout,
**self.connect_kwargs_async)

Expand All @@ -869,15 +874,15 @@ def connect_async(self, loop=None, timeout=None):
self._async_wait = None
raise
else:
self._task_data = TaskLocals(loop=self.loop)
self._task_data = TaskLocals(loop=self._loop)
self._async_conn = conn
self._async_wait.set_result(True)

@asyncio.coroutine
def cursor_async(self):
"""Acquire async cursor.
"""
yield from self.connect_async(loop=self.loop)
yield from self.connect_async(loop=self._loop)

if self.transaction_depth_async() > 0:
conn = self.transaction_conn_async()
Expand Down Expand Up @@ -956,15 +961,47 @@ def savepoint_async(self, sid=None):
"""
return savepoint(self, sid=sid)

def set_allow_sync(self, value):
"""Allow or forbid sync queries for the database. See also
the :meth:`.allow_sync()` context manager.
"""
self._allow_sync = value

@contextlib.contextmanager
def allow_sync(self):
"""Allow sync queries within context. Close sync
connection on exit if connected.
Example::
with database.allow_sync():
PageBlock.create_table(True)
"""
old_allow_sync = self._allow_sync
self._allow_sync = True

try:
yield
except:
raise
finally:
try:
self.close()
except self.Error:
pass # already closed

self._allow_sync = old_allow_sync

def execute_sql(self, *args, **kwargs):
"""Sync execute SQL query, `allow_sync` must be set to True.
"""
assert self.allow_sync, ("Error, sync query is not allowed: "
"allow_sync is False")
if self.allow_sync in (logging.ERROR, logging.WARNING):
logging.log(self.allow_sync,
"Error, sync query is not allowed: %s %s" %
str(args), str(kwargs))
assert self._allow_sync, (
"Error, sync query is not allowed! Call the `.set_allow_sync()` "
"or use the `.allow_sync()` context manager.")
if self._allow_sync in (logging.ERROR, logging.WARNING):
logging.log(self._allow_sync,
"Error, sync query is not allowed: %s %s" %
str(args), str(kwargs))
return super().execute_sql(*args, **kwargs)


Expand Down Expand Up @@ -1268,17 +1305,17 @@ def sync_unwanted(database):
`UnwantedSyncQueryError` exception will raise on such query.
NOTE: sync_unwanted() context manager is **deprecated**, use
database `allow_sync` property directly or via `Manager.allow_sync()`
database's `.allow_sync()` context manager or `Manager.allow_sync()`
context manager.
"""
warnings.warn("sync_unwanted() context manager is deprecated, "
"use database `allow_sync` property directly or "
"via Manager `allow_sync()` context manager. ",
"use database's `.allow_sync()` context manager or "
"`Manager.allow_sync()` context manager. ",
DeprecationWarning)
old_allow_sync = database.allow_sync
database.allow_sync = False
old_allow_sync = database._allow_sync
database._allow_sync = False
yield
database.allow_sync = old_allow_sync
database._allow_sync = old_allow_sync


class UnwantedSyncQueryError(Exception):
Expand Down

0 comments on commit effe078

Please sign in to comment.