Skip to content

Commit

Permalink
Merge pull request #75 from CyberROFL/aggregate_rows
Browse files Browse the repository at this point in the history
RowsCursor — fake cursor over awaited rows
  • Loading branch information
rudyryk committed Nov 14, 2017
2 parents 8c638d0 + 29ecc5c commit b34cf00
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
48 changes: 28 additions & 20 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,8 @@ def prefetch(query, *subqueries):
# NOTE! This is hacky, we perform async `execute()` and substitute result
# to the initial query:

prefetch_result.query._qr = yield from execute(prefetch_result.query)
qr = yield from execute(prefetch_result.query)
prefetch_result.query._qr = list(qr)
prefetch_result.query._dirty = False

for instance in prefetch_result.query._qr:
Expand All @@ -758,6 +759,23 @@ def prefetch(query, *subqueries):
RESULTS_AGGREGATE_MODELS = peewee.RESULTS_AGGREGATE_MODELS


class RowsCursor(object):
def __init__(self, rows, description):
self._rows = rows
self.description = description
self._idx = 0

def fetchone(self):
if self._idx >= len(self._rows):
return None
row = self._rows[self._idx]
self._idx += 1
return row

def close(self):
pass


class AsyncQueryWrapper:
"""Async query results wrapper for async `select()`. Internally uses
results wrapper produced by sync peewee select query.
Expand All @@ -773,19 +791,16 @@ class AsyncQueryWrapper:
def __init__(self, *, cursor=None, query=None):
self._initialized = False
self._cursor = cursor
self._result = []
self._rows = []
self._result_wrapper = self._get_result_wrapper(query)

def __iter__(self):
return iter(self._result)

def __getitem__(self, key):
return self._result[key]
while True:
yield self._result_wrapper.iterate()

def __len__(self):
return len(self._result)
return len(self._rows)

@classmethod
def _get_result_wrapper(self, query):
"""Get result wrapper class.
"""
Expand All @@ -801,26 +816,18 @@ def _get_result_wrapper(self, query):
else:
QRW = db.get_result_wrapper(RESULTS_MODELS)

return QRW(query.model_class, None, query.get_query_meta())
cursor = RowsCursor(self._rows, self._cursor.description)
return QRW(query.model_class, cursor, query.get_query_meta())

@asyncio.coroutine
def fetchone(self):
row = yield from self._cursor.fetchone()

if not row:
self._cursor = None
self._result_wrapper = None
raise GeneratorExit
elif not self._initialized:
self._result_wrapper.initialize(self._cursor.description)
self._initialized = True

obj = self._result_wrapper.process_row(row)
self._result.append(obj)
self._rows.append(row)


class AsyncRawQueryWrapper(AsyncQueryWrapper):
@classmethod
def _get_result_wrapper(self, query):
"""Get raw query result wrapper class.
"""
Expand All @@ -832,7 +839,8 @@ def _get_result_wrapper(self, query):
else:
QRW = db.get_result_wrapper(RESULTS_NAIVE)

return QRW(query.model_class, None, None)
cursor = RowsCursor(self._rows, self._cursor.description)
return QRW(query.model_class, cursor, None)


############
Expand Down
9 changes: 6 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,17 +543,20 @@ def test(objects):

result1 = yield from objects.execute(TestModel.raw(
'select id, text from testmodel'))
self.assertEqual(len(list(result1)), 1)
result1 = list(result1)
self.assertEqual(len(result1), 1)
self.assertTrue(isinstance(result1[0], TestModel))

result2 = yield from objects.execute(TestModel.raw(
'select id, text from testmodel').tuples())
self.assertEqual(len(list(result2)), 1)
result2 = list(result2)
self.assertEqual(len(result2), 1)
self.assertTrue(isinstance(result2[0], tuple))

result3 = yield from objects.execute(TestModel.raw(
'select id, text from testmodel').dicts())
self.assertEqual(len(list(result3)), 1)
result3 = list(result3)
self.assertEqual(len(result3), 1)
self.assertTrue(isinstance(result3[0], dict))

self.run_with_managers(test)
Expand Down

0 comments on commit b34cf00

Please sign in to comment.