Skip to content

Commit

Permalink
Merge pull request #19 from Downtownapp/master
Browse files Browse the repository at this point in the history
db initialization compatibile with peewee and transactions for pooled conn
  • Loading branch information
rudyryk committed Feb 16, 2016
2 parents b843fe3 + b0afe92 commit 8bc9d34
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 26 deletions.
108 changes: 86 additions & 22 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
"""
import asyncio
import datetime
import uuid

import aiopg
import peewee
import contextlib

from tasklocals import local

__version__ = '0.4.0'

__all__ = [
Expand Down Expand Up @@ -368,21 +371,30 @@ def __init__(self, loop, database, timeout, **kwargs):
self._loop = loop if loop else asyncio.get_event_loop()
self.database = database
self.timeout = timeout
self.dsn, self.connect_kwargs = _compose_dsn(self.database, **kwargs)
self.connect_kwargs = kwargs

@asyncio.coroutine
def get_conn(self):
return self._conn

def release(self, conn):
pass

@asyncio.coroutine
def connect(self):
"""Connect asynchronously.
"""
self._conn = yield from aiopg.connect(
dsn=self.dsn, timeout=self.timeout, loop=self._loop,
timeout=self.timeout, loop=self._loop, database=self.database,
**self.connect_kwargs)

@asyncio.coroutine
def cursor(self, *args, **kwargs):
def cursor(self, conn=None, *args, **kwargs):
"""Get connection cursor asynchronously.
"""
cursor = yield from self._conn.cursor(*args, **kwargs)
if conn is None:
conn = self._conn
cursor = yield from conn.cursor(*args, **kwargs)
cursor.release = lambda: None
return cursor

Expand All @@ -391,6 +403,10 @@ def close(self):
"""
self._conn.close()

@asyncio.coroutine
def close_async(self):
pass


class PooledAsyncConnection:
"""
Expand All @@ -401,34 +417,53 @@ def __init__(self, loop, database, timeout, **kwargs):
self._loop = loop if loop else asyncio.get_event_loop()
self.database = database
self.timeout = timeout
self.dsn, self.connect_kwargs = _compose_dsn(self.database, **kwargs)
self.connect_kwargs = kwargs

@asyncio.coroutine
def get_conn(self):
return (yield from self._pool.acquire())

def release(self, conn):
self._pool.release(conn)

@asyncio.coroutine
def connect(self):
"""Create connection pool asynchronously.
"""
self._pool = yield from aiopg.create_pool(
dsn=self.dsn,
loop=self._loop,
timeout=self.timeout,
database=self.database,
**self.connect_kwargs)

@asyncio.coroutine
def cursor(self, *args, **kwargs):
def cursor(self, conn=None, *args, **kwargs):
"""Get cursor for connection from pool.
"""
conn = yield from self._pool.acquire()
cursor = yield from conn.cursor(*args, **kwargs)
cursor.release = lambda: all((cursor.close(), self._pool.release(conn)))
if conn is None:
conn = yield from self._pool.acquire()
cursor = yield from conn.cursor(*args, **kwargs)

def releaser():
cursor.close()
self._pool.release(conn)
cursor.release = releaser
else:
cursor = yield from conn.cursor(*args, **kwargs)
cursor.release = lambda: cursor.close()
return cursor

def close(self):
"""Terminate all pool connections.
"""
self._pool.terminate()

@asyncio.coroutine
def close_async(self):
yield from self._pool.wait_closed()


class transaction(peewee._callable_context_manager):
class transaction:
"""Asynchronous context manager (`async with`), similar to
`peewee.transaction()`.
"""
Expand Down Expand Up @@ -476,7 +511,7 @@ def __aexit__(self, exc_type, exc_val, exc_tb):
self.db.pop_transaction()


class savepoint(peewee._callable_context_manager):
class savepoint:
"""Asynchronous context manager (`async with`), similar to
`peewee.savepoint()`.
"""
Expand Down Expand Up @@ -520,7 +555,7 @@ def __aexit__(self, exc_type, exc_val, exc_tb):
self.db.set_autocommit(self._orig_autocommit)


class atomic(peewee._callable_context_manager):
class atomic:
"""Asynchronous context manager (`async with`), similar to
`peewee.atomic()`.
"""
Expand All @@ -533,20 +568,24 @@ def __aenter__(self):
self._helper = self.db.transaction_async()
else:
self._helper = self.db.savepoint_async()
yield from self._helper.__aenter__()
self.db.locals.transaction_conn = yield from self.db._async_conn.get_conn()
return (yield from self._helper.__aenter__())

@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
yield from self._helper.__aexit__(exc_type, exc_val, exc_tb)
self.db._async_conn.release(self.db.locals.transaction_conn)
self.db.locals.transaction_conn = None


class AsyncPostgresqlMixin:
"""Mixin for peewee database class providing extra methods
for managing async connection.
"""
def init_async(self, conn_cls=AsyncConnection, **kwargs):
def init_async(self, database, conn_cls=AsyncConnection, **kwargs):
self.allow_sync = True

self.deferred = database is None
self.database = database
self._loop = None
self._async_conn = None
self._async_conn_cls = conn_cls
Expand All @@ -555,12 +594,21 @@ def init_async(self, conn_cls=AsyncConnection, **kwargs):
'enable_hstore': False,
}
self._async_kwargs.update(kwargs)
self.connect_kwargs = kwargs.copy()
self.connect_kwargs.pop('enable_json', None)
self.connect_kwargs.pop('enable_hstore', None)

self.locals = local()

@asyncio.coroutine
def connect_async(self, loop=None, timeout=None):
"""Set up async connection on specified event loop or
on default event loop.
"""
if self.deferred:
raise Exception('Error, database not properly initialized '
'before opening connection')

if not self._async_conn:
self._loop = loop if loop else asyncio.get_event_loop()
self._async_conn = self._async_conn_cls(
Expand Down Expand Up @@ -594,6 +642,14 @@ def last_insert_id_async(self, cursor, model):
result = (yield from cursor.fetchone())[0]
return result

@asyncio.coroutine
def close_async(self, loop=None):
self.close()
yield from self._async_conn.close_async()
if self._async_conn:
self._async_conn = None
self._loop = None

def atomic_async(self):
"""Similar to peewee `Database.atomic()` method, but returns
asynchronous context manager.
Expand All @@ -619,8 +675,6 @@ def close(self):

if self._async_conn:
self._async_conn.close()
self._async_conn = None
self._loop = None

def execute_sql(self, *args, **kwargs):
"""Sync execute SQL query. If this query is performing within
Expand All @@ -633,6 +687,13 @@ def execute_sql(self, *args, **kwargs):
return super().execute_sql(*args, **kwargs)


class PooledAsyncPostgresqlMixin(AsyncPostgresqlMixin):
def init_async(self, database, conn_cls=PooledAsyncConnection, minsize=1, maxsize=20, **kwargs):
super(PooledAsyncPostgresqlMixin, self).init_async(database, conn_cls, **kwargs)
self._async_kwargs['minsize'] = minsize
self._async_kwargs['maxsize'] = maxsize


class PostgresqlDatabase(AsyncPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync** connection
and **single async connection** interface.
Expand All @@ -646,10 +707,10 @@ def __init__(self, database, threadlocals=True, autocommit=True,
fields=fields, ops=ops, autorollback=autorollback,
**kwargs)

self.init_async(**self.connect_kwargs)
self.init_async(database, **self.connect_kwargs)


class PooledPostgresqlDatabase(AsyncPostgresqlMixin, peewee.PostgresqlDatabase):
class PooledPostgresqlDatabase(PooledAsyncPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync**
connection and **async connections pool** interface.
Expand All @@ -665,7 +726,7 @@ def __init__(self, database, threadlocals=True, autocommit=True,
fields=fields, ops=ops, autorollback=autorollback,
**kwargs)

self.init_async(conn_cls=PooledAsyncConnection, minsize=1,
self.init_async(database, conn_cls=PooledAsyncConnection, minsize=1,
maxsize=max_connections, **self.connect_kwargs)


Expand All @@ -691,7 +752,10 @@ def _run_sql(db, operation, *args, **kwargs):
"""Run SQL operation (query or command) against database.
"""
assert db._async_conn, "Error, no async database connection."
cursor = yield from db._async_conn.cursor()
if getattr(db.locals, 'transaction_conn', None) is not None:
cursor = yield from db._async_conn.cursor(conn=db.locals.transaction_conn)
else:
cursor = yield from db._async_conn.cursor()
try:
yield from cursor.execute(operation, *args, **kwargs)
except Exception as e:
Expand Down
8 changes: 4 additions & 4 deletions peewee_asyncext.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Copyright (c) 2014, Alexey Kinev <rudy@05bit.com>
"""
from peewee_async import AsyncPostgresqlMixin, PooledAsyncConnection
from peewee_async import AsyncPostgresqlMixin, PooledAsyncConnection, PooledAsyncPostgresqlMixin
import playhouse.postgres_ext as ext


Expand All @@ -35,10 +35,10 @@ def __init__(self, database, threadlocals=True, autocommit=True,
'enable_json': True,
'enable_hstore': self.register_hstore,
})
self.init_async(**async_kwargs)
self.init_async(database, **async_kwargs)


class PooledPostgresqlExtDatabase(AsyncPostgresqlMixin, ext.PostgresqlExtDatabase):
class PooledPostgresqlExtDatabase(PooledAsyncPostgresqlMixin, ext.PostgresqlExtDatabase):
"""PosgreSQL database extended driver providing **single drop-in sync**
connection and **async connections pool** interface.
Expand All @@ -59,5 +59,5 @@ def __init__(self, database, threadlocals=True, autocommit=True,
'enable_json': True,
'enable_hstore': self.register_hstore,
})
self.init_async(conn_cls=PooledAsyncConnection, minsize=1,
self.init_async(database, conn_cls=PooledAsyncConnection, minsize=1,
maxsize=max_connections, **async_kwargs)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
install_requires=(
'peewee>=2.6.4',
'aiopg>=0.7.0',
'tasklocals>=0.2',
),
py_modules=[
'peewee_async',
Expand Down
8 changes: 8 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def setUpModule():
if 'max_connections' in db_cfg:
db_cfg['max_connections'] = int(db_cfg['max_connections'])
use_pool = db_cfg['max_connections'] > 1
if not use_pool:
db_cfg.pop('max_connections')
else:
use_pool = False

Expand Down Expand Up @@ -189,6 +191,12 @@ def tearDownClass(cls, *args, **kwargs):

# Close database
database.close()
# Async connect
cls.loop = asyncio.get_event_loop()
@asyncio.coroutine
def close():
yield from database.close_async(loop=cls.loop)
cls.loop.run_until_complete(close())

def run_until_complete(self, coroutine):
result = self.loop.run_until_complete(coroutine)
Expand Down

0 comments on commit 8bc9d34

Please sign in to comment.