Skip to content

Commit

Permalink
MySQL connection should be closed correctly now, see #26 + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rudyryk committed May 6, 2016
1 parent 6d1ad22 commit 26793e9
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 52 deletions.
101 changes: 49 additions & 52 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def select(query):
except GeneratorExit:
pass

cursor.release()
yield from cursor.release
return result


Expand All @@ -584,7 +584,7 @@ def insert(query):
result = yield from query.database.last_insert_id_async(
cursor, query.model_class)

cursor.release()
yield from cursor.release
return result


Expand All @@ -599,7 +599,7 @@ def update(query):
cursor = yield from _execute_query_async(query)
rowcount = cursor.rowcount

cursor.release()
yield from cursor.release
return rowcount


Expand All @@ -614,7 +614,7 @@ def delete(query):
cursor = yield from _execute_query_async(query)
rowcount = cursor.rowcount

cursor.release()
yield from cursor.release
return rowcount


Expand Down Expand Up @@ -650,7 +650,7 @@ def scalar(query, as_tuple=False):
cursor = yield from _execute_query_async(query)
row = yield from cursor.fetchone()

cursor.release()
yield from cursor.release
if row and not as_tuple:
return row[0]
else:
Expand All @@ -672,7 +672,7 @@ def raw_query(query):
except GeneratorExit:
pass

cursor.release()
yield from cursor.release
return result


Expand Down Expand Up @@ -983,36 +983,35 @@ def connect(self):
**self.connect_kwargs)

@asyncio.coroutine
def cursor(self, conn=None, *args, **kwargs):
"""Get cursor for connection from pool.
def close(self):
"""Terminate all pool connections.
"""
if conn is None:
# Acquire connection with cursor, once cursor is released
# connection is also released to pool:
self.pool.terminate()
yield from self.pool.wait_closed()

@asyncio.coroutine
def cursor(self, conn=None, *args, **kwargs):
"""Get a cursor for the specified transaction connection
or acquire from the pool.
"""
in_transaction = conn is not None
if not conn:
conn = yield from self.acquire()
cursor = yield from conn.cursor(*args, **kwargs)

def release():
cursor.close()
self.pool.release(conn)
cursor.release = release
else:
# Acquire cursor from provided connection, after cursor is
# released connection is NOT released to pool, i.e.
# for handling transactions:

cursor = yield from conn.cursor(*args, **kwargs)
cursor.release = lambda: cursor.close()

cursor = yield from conn.cursor(*args, **kwargs)
# NOTE: `cursor.release` is an awaitable object!
cursor.release = self.release_cursor(
cursor, in_transaction=in_transaction)
return cursor

@asyncio.coroutine
def close(self):
"""Terminate all pool connections.
def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
self.pool.terminate()
yield from self.pool.wait_closed()
conn = cursor.connection
cursor.close()
if not in_transaction:
self.pool.release(conn)


class AsyncPostgresqlMixin(AsyncDatabase):
Expand Down Expand Up @@ -1143,37 +1142,35 @@ def connect(self):
connect_timeout=self.timeout,
**self.connect_kwargs)

@asyncio.coroutine
def close(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
yield from self.pool.wait_closed()

@asyncio.coroutine
def cursor(self, conn=None, *args, **kwargs):
"""Get cursor for connection from pool.
"""
if conn is None:
# Acquire connection with cursor, once cursor is released
# connection is also released to pool:

in_transaction = conn is not None
if not conn:
conn = yield from self.acquire()
cursor = yield from conn.cursor(*args, **kwargs)

def release():
cursor.close()
self.pool.release(conn)
cursor.release = release
else:
# Acquire cursor from provided connection, after cursor is
# released connection is NOT released to pool, i.e.
# for handling transactions:

cursor = yield from conn.cursor(*args, **kwargs)
cursor.release = lambda: cursor.close()

cursor = yield from conn.cursor(*args, **kwargs)
# NOTE: `cursor.release` is an awaitable object!
cursor.release = self.release_cursor(
cursor, in_transaction=in_transaction)
return cursor

@asyncio.coroutine
def close(self):
"""Terminate all pool connections.
def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
self.pool.terminate()
yield from self.pool.wait_closed()
conn = cursor.connection
yield from cursor.close()
if not in_transaction:
self.pool.release(conn)


class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
Expand Down Expand Up @@ -1395,7 +1392,7 @@ def _run_sql(database, operation, *args, **kwargs):
try:
yield from cursor.execute(operation, *args, **kwargs)
except:
cursor.release()
yield from cursor.release
raise

return cursor
Expand Down
14 changes: 14 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,20 @@ def test(objects):

self.run_with_managers(test)

def test_many_requests(self):
@asyncio.coroutine
def test(objects):
max_connections = getattr(objects.database, 'max_connections', 0)
text = "Test %s" % uuid.uuid4()
obj = yield from objects.create(TestModel, text=text)
n = 2 * max_connections # number of requests
done, not_done = yield from asyncio.wait(
[objects.get(TestModel, id=obj.id) for _ in range(n)],
loop=self.loop)
self.assertEqual(len(done), n)

self.run_with_managers(test)

def test_create_obj(self):
@asyncio.coroutine
def test(objects):
Expand Down

0 comments on commit 26793e9

Please sign in to comment.