Skip to content

Commit

Permalink
Internals refinements; database swapping feature hidden from public i…
Browse files Browse the repository at this point in the history
…nterface
  • Loading branch information
rudyryk committed Apr 17, 2016
1 parent 802c1fc commit 5111be7
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 78 deletions.
153 changes: 80 additions & 73 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@
### High level API ###

'Manager',
'AutoDatabase',
'PostgresqlDatabase',
'PooledPostgresqlDatabase',

### Low level API ###

'PostgresqlDatabase',
'PooledPostgresqlDatabase',
'execute',
'get_object',
'create_object',
Expand Down Expand Up @@ -98,11 +97,12 @@ class MyManager(Manager):

def __init__(self, database=None, *, loop=None):
assert database or self.database, \
("Error, database should be provided via "
("Error, database must be provided via "
"argument or class member.")

self.loop = loop or asyncio.get_event_loop()
self.database = database or self.database
self.database.allow_sync = False
self.database.loop = self.loop

@property
def is_connected(self):
Expand Down Expand Up @@ -235,7 +235,7 @@ def create_or_get(self, model, **kwargs):
def execute(self, query):
"""Execute query asyncronously.
"""
query = yield from self._prepare_query(query)
query = self._swap_database(query)
return (yield from execute(query))

@asyncio.coroutine
Expand All @@ -244,7 +244,7 @@ def prefetch(self, query, *subqueries):
:return: Query that has already cached data for subqueries
"""
query = yield from self._prepare_query(query)
query = self._swap_database(query)
subqueries = map(self._swap_database, subqueries)
return (yield from prefetch(query, *subqueries))

Expand All @@ -254,7 +254,7 @@ def count(self, query, clear_limit=False):
:return: number of objects in ``select()`` query
"""
query = yield from self._prepare_query(query)
query = self._swap_database(query)
return (yield from count(query, clear_limit=clear_limit))

@asyncio.coroutine
Expand All @@ -263,24 +263,9 @@ def scalar(self, query, as_tuple=False):
:return: result is the same as after sync ``query.scalar()`` call
"""
query = yield from self._prepare_query(query)
query = self._swap_database(query)
return (yield from scalar(query, as_tuple=as_tuple))

# @asyncio.coroutine
def atomic(self):
# yield from self.connect()
return atomic(self.database)

@asyncio.coroutine
def transaction(self):
yield from self.connect()
return transaction(self.database)

@asyncio.coroutine
def savepoint(self, sid=None):
yield from self.connect()
return savepoint(self.database, sid=sid)

@asyncio.coroutine
def connect(self):
"""Open database async connection if not connected.
Expand All @@ -293,6 +278,15 @@ def close(self):
"""
yield from self.database.close_async()

def atomic(self):
return atomic(self.database)

def transaction(self):
return transaction(self.database)

def savepoint(self, sid=None):
return savepoint(self.database, sid=sid)

@contextlib.contextmanager
def allow_sync(self, allow=True):
"""Allow sync queries within context.
Expand All @@ -302,14 +296,6 @@ def allow_sync(self, allow=True):
yield
self.database.allow_sync = old_allow

@asyncio.coroutine
def _prepare_query(self, query):
"""Connect to database if not connected and swap database
for query to execute against manager's database.
"""
yield from self.connect()
return self._swap_database(query)

def _swap_database(self, query):
"""Swap database for query if swappable. Return **new query**
with swapped database.
Expand All @@ -321,7 +307,7 @@ def _swap_database(self, query):
database, it's **WRONG AND DANGEROUS**, so assertion is raised.
"""
model = query.model_class
if model._meta.database == AutoDatabase:
if model._meta.database == _AutoDatabase:
query = query.clone()
query.database = self.database
return query
Expand All @@ -347,8 +333,9 @@ def _prune_fields(field_dict, only):
return field_dict


class AutoDatabase:
"""Swappable database placeholder. Doesn't contain any implementation.
class _AutoDatabase:
"""Experimental swappable database placeholder.
Doesn't contain any implementation details.
"""
# Both PostgreSQL and MySQL need commiting SELECT
commit_select = True
Expand Down Expand Up @@ -748,26 +735,30 @@ def fetchone(self):
class AsyncPooledPostgresqlConnection:
"""Asynchronous database connection pool wrapper.
"""
def __init__(self, loop, database, timeout, **kwargs):
self._pool = None
self._loop = loop if loop else asyncio.get_event_loop()
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout
self.connect_kwargs = kwargs

@asyncio.coroutine
def acquire(self):
return (yield from self._pool.acquire())
"""Acquire connection from pool.
"""
return (yield from self.pool.acquire())

def release(self, conn):
self._pool.release(conn)
"""Release connection to pool.
"""
self.pool.release(conn)

@asyncio.coroutine
def connect(self):
"""Create connection pool asynchronously.
"""
self._pool = yield from aiopg.create_pool(
loop=self._loop,
self.pool = yield from aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_kwargs)
Expand All @@ -785,7 +776,7 @@ def cursor(self, conn=None, *args, **kwargs):

def release():
cursor.close()
self._pool.release(conn)
self.pool.release(conn)
cursor.release = release
else:
# Acquire cursor from provided connection, after cursor is
Expand All @@ -801,17 +792,18 @@ def release():
def close(self):
"""Terminate all pool connections.
"""
self._pool.terminate()
yield from self._pool.wait_closed()
self.pool.terminate()
yield from self.pool.wait_closed()


class AsyncPostgresqlConnection(AsyncPooledPostgresqlConnection):
"""Asynchronous single database connection wrapper.
"""
def __init__(self, loop, database, timeout, **kwargs):
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
kwargs['minsize'] = 1
kwargs['maxsize'] = 1
super().__init__(loop, database, timeout, **kwargs)
super().__init__(database=database, loop=loop, timeout=timeout,
**kwargs)


class AsyncPostgresqlMixin:
Expand All @@ -823,10 +815,11 @@ def init_async(self, conn_cls=AsyncPostgresqlConnection, enable_json=False,
if not aiopg:
raise Exception("Error, aiopg is not installed!")

self.allow_sync = True
self._loop = None
self.allow_sync = True
self.loop = None
self._async_conn = None
self._async_conn_cls = conn_cls
self._async_wait = None
self._enable_json = enable_json
self._enable_hstore = enable_hstore

Expand All @@ -850,25 +843,52 @@ def connect_async(self, loop=None, timeout=None):
raise Exception("Error, database not properly initialized "
"before opening connection")

if not self._async_conn:
loop = loop or asyncio.get_event_loop()
if self._async_wait:
yield from self._async_wait
else:
self.loop = loop or asyncio.get_event_loop()
self._async_wait = asyncio.Future(loop=self.loop)

conn = self._async_conn_cls(
loop, self.database,
timeout if timeout else aiopg.DEFAULT_TIMEOUT,
database=self.database,
loop=self.loop,
timeout=(timeout or aiopg.DEFAULT_TIMEOUT),
**self.connect_kwargs_async)

yield from conn.connect()

self._loop = loop
self._task_data = tasklocals.local(loop=loop)
self._task_data = tasklocals.local(loop=self.loop)
self._async_conn = conn
self._async_wait.set_result(True)

@asyncio.coroutine
def cursor_async(self):
"""Acquire async cursor.
"""
if not self._async_conn:
yield from self.connect_async(loop=self.loop)

if self.transaction_depth_async() > 0:
conn = self.transaction_conn_async()
else:
conn = None

try:
return (yield from self._async_conn.cursor(conn=conn))
except:
yield from self.close_async()
raise

@asyncio.coroutine
def close_async(self):
"""Close async connection.
"""
if self._async_wait:
yield from self._async_wait
if self._async_conn:
conn = self._async_conn
self._async_conn = None
self._loop = None
self._async_wait = None
yield from conn.close()

@asyncio.coroutine
Expand Down Expand Up @@ -899,10 +919,11 @@ def last_insert_id_async(self, cursor, model):
def push_transaction_async(self):
"""Increment async transaction depth.
"""
if not self._async_conn:
yield from self.connect_async(loop=self.loop)
if not getattr(self._task_data, 'depth', 0):
self._task_data.depth = 0
self._task_data.conn = yield from self._async_conn.acquire()

self._task_data.depth += 1

@asyncio.coroutine
Expand All @@ -911,7 +932,6 @@ def pop_transaction_async(self):
"""
if self._task_data.depth > 0:
self._task_data.depth -= 1

if self._task_data.depth == 0:
self._async_conn.release(self._task_data.conn)
self._task_data.conn = None
Expand Down Expand Up @@ -1145,21 +1165,8 @@ def __aexit__(self, exc_type, exc_val, exc_tb):
@asyncio.coroutine
def _run_sql(db, operation, *args, **kwargs):
"""Run SQL operation (query or command) against database.
"""
if not db._async_conn:
raise peewee.DatabaseError("Error, async database connection "
"is not set up.")

if db.transaction_depth_async() > 0:
conn = db.transaction_conn_async()
else:
conn = None

try:
cursor = yield from db._async_conn.cursor(conn=conn)
except:
yield from db.close_async()
raise
"""
cursor = yield from db.cursor_async()

try:
yield from cursor.execute(operation, *args, **kwargs)
Expand Down
10 changes: 5 additions & 5 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,38 +86,38 @@ class TestModel(peewee.Model):
text = peewee.CharField()

class Meta:
database = peewee_async.AutoDatabase
database = peewee_async._AutoDatabase


class TestModelAlpha(peewee.Model):
text = peewee.CharField()

class Meta:
database = peewee_async.AutoDatabase
database = peewee_async._AutoDatabase


class TestModelBeta(peewee.Model):
alpha = peewee.ForeignKeyField(TestModelAlpha, related_name='betas')
text = peewee.CharField()

class Meta:
database = peewee_async.AutoDatabase
database = peewee_async._AutoDatabase


class TestModelGamma(peewee.Model):
text = peewee.CharField()
beta = peewee.ForeignKeyField(TestModelBeta, related_name='gammas')

class Meta:
database = peewee_async.AutoDatabase
database = peewee_async._AutoDatabase


class UUIDTestModel(peewee.Model):
id = peewee.UUIDField(primary_key=True, default=uuid.uuid4)
text = peewee.CharField()

class Meta:
database = peewee_async.AutoDatabase
database = peewee_async._AutoDatabase


####################
Expand Down

0 comments on commit 5111be7

Please sign in to comment.