Skip to content

Commit

Permalink
Get rid of dirty hack in select() coroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
rudyryk committed Feb 19, 2016
1 parent b161d88 commit f2a3ef9
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
'savepoint',
]

RESULTS_NAIVE = peewee.RESULTS_NAIVE
RESULTS_MODELS = peewee.RESULTS_MODELS
RESULTS_TUPLES = peewee.RESULTS_TUPLES
RESULTS_DICTS = peewee.RESULTS_DICTS
RESULTS_AGGREGATE_MODELS = peewee.RESULTS_AGGREGATE_MODELS


#################
# Async queries #
Expand Down Expand Up @@ -206,24 +212,14 @@ def select(query):
("Error, trying to run select coroutine"
"with wrong query class %s" % str(query))

# Perform *real* async query
query = query.clone()
cursor = yield from _execute_query_async(query)

# Perform *fake* query: we only need a result wrapper
# here, not the query result itself:
query._execute = lambda: None
result_wrapper = query.execute()

# Fetch result
result = AsyncQueryResult(result_wrapper=result_wrapper, cursor=cursor)
result = AsyncQueryWrapper(query)
cursor = yield from result.execute()
try:
while True:
yield from result.fetchone()
except GeneratorExit:
pass

# Release cursor and return
cursor.release()
return result

Expand Down Expand Up @@ -375,7 +371,7 @@ def prefetch(sq, *subqueries):
###################


class AsyncQueryResult:
class AsyncQueryWrapper:
"""Async query results wrapper for async `select()`. Internally uses
results wrapper produced by sync peewee select query.
Expand All @@ -387,11 +383,12 @@ class AsyncQueryResult:
To retrieve results after async fetching just iterate over this class
instance, like you generally iterate over sync results wrapper.
"""
def __init__(self, result_wrapper=None, cursor=None):
self._result = []
def __init__(self, query):
self._initialized = False
self._result_wrapper = result_wrapper
self._cursor = cursor
self._cursor = None
self._query = query
self._result = []
self._result_wrapper = self._get_result_wrapper(query)

def __iter__(self):
return iter(self._result)
Expand All @@ -402,6 +399,28 @@ def __getitem__(self, key):
def __len__(self):
return len(self._result)

@classmethod
def _get_result_wrapper(self, query):
"""Get result wrapper class.
"""
if query._tuples:
QR = query.database.get_result_wrapper(RESULTS_TUPLES)
elif query._dicts:
QR = query.database.get_result_wrapper(RESULTS_DICTS)
elif query._naive or not query._joins or query.verify_naive():
QR = query.database.get_result_wrapper(RESULTS_NAIVE)
elif query._aggregate_rows:
QR = query.database.get_result_wrapper(RESULTS_AGGREGATE_MODELS)
else:
QR = query.database.get_result_wrapper(RESULTS_MODELS)

return QR(query.model_class, None, query.get_query_meta())

@asyncio.coroutine
def execute(self):
self._cursor = yield from _execute_query_async(self._query)
return self._cursor

@asyncio.coroutine
def fetchone(self):
row = yield from self._cursor.fetchone()
Expand Down

0 comments on commit f2a3ef9

Please sign in to comment.