Skip to content

Commit

Permalink
Add test for deferred database init, clean up some mess with async co…
Browse files Browse the repository at this point in the history
…nnect and close methods
  • Loading branch information
rudyryk committed Feb 19, 2016
1 parent 8bc9d34 commit f2b4c52
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 68 deletions.
46 changes: 9 additions & 37 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,12 @@ def cursor(self, conn=None, *args, **kwargs):
cursor.release = lambda: None
return cursor

@asyncio.coroutine
def close(self):
"""Close connection.
"""
self._conn.close()

@asyncio.coroutine
def close_async(self):
pass


class PooledAsyncConnection:
"""
Expand Down Expand Up @@ -453,13 +450,11 @@ def releaser():
cursor.release = lambda: cursor.close()
return cursor

@asyncio.coroutine
def close(self):
"""Terminate all pool connections.
"""
self._pool.terminate()

@asyncio.coroutine
def close_async(self):
yield from self._pool.wait_closed()


Expand Down Expand Up @@ -618,6 +613,13 @@ def connect_async(self, loop=None, timeout=None):
**self._async_kwargs)
yield from self._async_conn.connect()

@asyncio.coroutine
def close_async(self):
if self._async_conn:
yield from self._async_conn.close()
self._async_conn = None
self._loop = None

@asyncio.coroutine
def last_insert_id_async(self, cursor, model):
"""Get ID of last inserted row.
Expand All @@ -642,14 +644,6 @@ def last_insert_id_async(self, cursor, model):
result = (yield from cursor.fetchone())[0]
return result

@asyncio.coroutine
def close_async(self, loop=None):
self.close()
yield from self._async_conn.close_async()
if self._async_conn:
self._async_conn = None
self._loop = None

def atomic_async(self):
"""Similar to peewee `Database.atomic()` method, but returns
asynchronous context manager.
Expand All @@ -668,14 +662,6 @@ def savepoint_async(self, sid=None):
"""
return savepoint(self, sid=sid)

def close(self):
"""Close both sync and async connections.
"""
super().close()

if self._async_conn:
self._async_conn.close()

def execute_sql(self, *args, **kwargs):
"""Sync execute SQL query. If this query is performing within
`sync_unwanted()` context, then `UnwantedSyncQueryError` exception
Expand Down Expand Up @@ -772,20 +758,6 @@ def _execute_query_async(query):
return (yield from _run_sql(db, *query.sql()))


def _compose_dsn(dbname, **kwargs):
"""Compose DSN string by set of connection parameters.
Extract parameters: dbname, user, password, host, port.
Return DSN string and remain parameters dict.
"""
dsn = 'dbname=%s' % dbname
for k in ('user', 'password', 'host', 'port'):
v = kwargs.pop(k, None)
if v:
dsn += ' %s=%s' % (k, v)
return dsn, kwargs


@asyncio.coroutine
def prefetch(sq, *subqueries):
"""Asynchronous version of the prefetch function from peewee.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
license='MIT',
zip_safe=False,
install_requires=(
'peewee>=2.6.4',
'peewee>=2.8.0',
'aiopg>=0.7.0',
'tasklocals>=0.2',
),
Expand Down
69 changes: 39 additions & 30 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,21 @@ def __setattr__(self, attr, value):
sync_unwanted = peewee_async.sync_unwanted

# Globals
config = {}
db_params = {}
database = ProxyDatabase()


def setUpModule():
global config
global db_params
global database

ini_config = configparser.ConfigParser()
ini_config.read(['tests.ini'])
ini = configparser.ConfigParser()
ini.read(['tests.ini'])

try:
config = dict(**ini_config['tests'])
config = dict(**ini['tests'])
except KeyError:
pass
config = {}

config.setdefault('database', 'test')
config.setdefault('host', '127.0.0.1')
Expand All @@ -78,16 +78,15 @@ def setUpModule():
config['host'] = url.host or config['host']
config['port'] = url.port or config['port']

db_cfg = config.copy()
use_ext = db_cfg.pop('use_ext', False)
db_params = config.copy()
use_ext = db_params.pop('use_ext', False)
use_pool = False

if 'max_connections' in db_cfg:
db_cfg['max_connections'] = int(db_cfg['max_connections'])
use_pool = db_cfg['max_connections'] > 1
if 'max_connections' in db_params:
db_params['max_connections'] = int(db_params['max_connections'])
use_pool = db_params['max_connections'] > 1
if not use_pool:
db_cfg.pop('max_connections')
else:
use_pool = False
db_params.pop('max_connections')

if use_pool:
if use_ext:
Expand All @@ -100,7 +99,7 @@ def setUpModule():
else:
db_cls = peewee_async.PostgresqlDatabase

database.conn = db_cls(**db_cfg)
database.conn = db_cls(**db_params)


class TestModel(peewee.Model):
Expand Down Expand Up @@ -141,8 +140,25 @@ class Meta:
database = database


class PostgresInitTestCase(unittest.TestCase):
def test_deferred_init(self):
db = peewee_async.PooledPostgresqlDatabase(None)
self.assertTrue(db.deferred)

db.init(**db_params)

loop = asyncio.get_event_loop()
loop.run_until_complete(db.connect_async(loop=loop))
# Should not fail connect again
loop.run_until_complete(db.connect_async(loop=loop))
loop.run_until_complete(db.close_async())
# Should not closing connect again
loop.run_until_complete(db.close_async())


class BaseAsyncPostgresTestCase(unittest.TestCase):
db_tables = [TestModel, UUIDTestModel, TestModelAlpha, TestModelBeta, TestModelGamma]
db_tables = [TestModel, UUIDTestModel, TestModelAlpha,
TestModelBeta, TestModelGamma]

@classmethod
def setUpClass(cls, *args, **kwargs):
Expand All @@ -151,17 +167,14 @@ def setUpClass(cls, *args, **kwargs):

# Async connect
cls.loop = asyncio.get_event_loop()
@asyncio.coroutine
def test():
yield from database.connect_async(loop=cls.loop)
cls.loop.run_until_complete(test())
cls.loop.run_until_complete(database.connect_async(loop=cls.loop))

# Clean up after possible errors
for table in reversed(cls.db_tables):
# Clean up after possible errors
table.drop_table(True, cascade=True)

# Create tables with sync connection
for table in cls.db_tables:
# Create table with sync connection
table.create_table()

# Create at least one object per model
Expand All @@ -182,21 +195,17 @@ def test():

cls.gamma_121 = TestModelGamma.create(text='Gamma 1', beta=cls.beta_12)


@classmethod
def tearDownClass(cls, *args, **kwargs):
for table in reversed(cls.db_tables):
# Finally, clean up
for table in reversed(cls.db_tables):
table.drop_table()

# Close database
database.close()
# Async connect
cls.loop = asyncio.get_event_loop()
@asyncio.coroutine
def close():
yield from database.close_async(loop=cls.loop)
cls.loop.run_until_complete(close())

# Async disconnect
cls.loop.run_until_complete(database.close_async())

def run_until_complete(self, coroutine):
result = self.loop.run_until_complete(coroutine)
Expand Down

0 comments on commit f2b4c52

Please sign in to comment.