Skip to content

Commit

Permalink
Add public database flag 'allow_sync'
Browse files Browse the repository at this point in the history
  • Loading branch information
rudyryk committed Dec 5, 2015
1 parent 101b0aa commit a47518a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 33 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 0.3.3

- Add public `allow_sync` flag to database class, `True` by default
- Remove arguments from `sync_unwanted()` context manager function

## 0.3

- #7, fixed bug with empty result after inserting row with UUID pk
Expand Down
52 changes: 25 additions & 27 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import peewee
import contextlib

__version__ = '0.3.2'
__version__ = '0.3.3'

__all__ = [
# Queries
Expand Down Expand Up @@ -306,8 +306,8 @@ def scalar(query, as_tuple=False):
def cursor_with_query(query):
"""Execute query and return cursor object.
"""
assert query.database.async_conn, "Error, no async database connection."
cursor = yield from query.database.async_conn.cursor()
assert query.database._async_conn, "Error, no async database connection."
cursor = yield from query.database._async_conn.cursor()
yield from cursor.execute(*query.sql())
return cursor

Expand Down Expand Up @@ -442,38 +442,30 @@ class AsyncPostgresqlMixin:
for managing async connection.
"""
def init_async(self, conn_cls=AsyncConnection, **kwargs):
self._sync_unwanted = False
self._loop = None
self.allow_sync = True

self.async_conn = None
self.async_conn_cls = conn_cls
self.async_kwargs = {
self._loop = None
self._async_conn = None
self._async_conn_cls = conn_cls
self._async_kwargs = {
'enable_json': False,
'enable_hstore': False,
}
self.async_kwargs.update(kwargs)
self._async_kwargs.update(kwargs)

@asyncio.coroutine
def connect_async(self, loop=None, timeout=None):
"""Set up async connection on specified event loop or
on default event loop.
"""
if not self.async_conn:
if not self._async_conn:
self._loop = loop if loop else asyncio.get_event_loop()
self.async_conn = self.async_conn_cls(
self._async_conn = self._async_conn_cls(
self._loop,
self.database,
timeout if timeout else aiopg.DEFAULT_TIMEOUT,
**self.async_kwargs)
yield from self.async_conn.connect()

def close_async(self):
"""Close asynchronous connection.
"""
if self.async_conn:
self.async_conn.close()
self.async_conn = None
self._loop = None
**self._async_kwargs)
yield from self._async_conn.connect()

@asyncio.coroutine
def last_insert_id_async(self, cursor, model):
Expand Down Expand Up @@ -503,15 +495,20 @@ def close(self):
"""Close both sync and async connections.
"""
super().close()
self.close_async()

if self._async_conn:
self._async_conn.close()
self._async_conn = None
self._loop = None

def execute_sql(self, *args, **kwargs):
"""Sync execute SQL query. If this query is performing within
`sync_unwanted()` context, then `UnwantedSyncQueryError` exception
is raised.
"""
if self._sync_unwanted:
raise UnwantedSyncQueryError("Error, unwanted sync query", args, kwargs)
if not self.allow_sync:
raise UnwantedSyncQueryError("Error, unwanted sync query",
args, kwargs)
return super().execute_sql(*args, **kwargs)


Expand Down Expand Up @@ -552,13 +549,14 @@ def __init__(self, database, threadlocals=True, autocommit=True,


@contextlib.contextmanager
def sync_unwanted(database, enabled=True):
def sync_unwanted(database):
"""Context manager for preventing unwanted sync queries.
`UnwantedSyncQueryError` exception will raise on such query.
"""
database._sync_unwanted = enabled
old_allow_sync = database.allow_sync
database.allow_sync = False
yield
database._sync_unwanted = False
database.allow_sync = old_allow_sync


class UnwantedSyncQueryError(Exception):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from setuptools import setup

__version__ = '0.3.2'
__version__ = '0.3.3'

setup(
name="peewee-async",
Expand Down
10 changes: 5 additions & 5 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@
db_cls = peewee_async.PooledPostgresqlDatabase

database = db_cls(config['db'],
user=config['user'],
password=config['password'],
max_connections=int(config['pool_size']))
user=config['user'],
password=config['password'],
max_connections=int(config['pool_size']))
else:
if config.get('ext', None):
db_cls = peewee_asyncext.PostgresqlExtDatabase
else:
db_cls = peewee_async.PostgresqlDatabase

database = db_cls(config['db'],
user=config['user'],
password=config['password'])
user=config['user'],
password=config['password'])

#
# Tests
Expand Down

0 comments on commit a47518a

Please sign in to comment.