Skip to content

Commit

Permalink
Add proper support for RawQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
rudyryk committed Apr 18, 2016
1 parent acb8da5 commit 37a49a4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 22 deletions.
64 changes: 42 additions & 22 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class User(peewee.Model):
user1 = await objects.get(User, id=user0.id)
user2 = await objects.get(User, username='test')
# All should be the same
print(user1, user2, user3)
print(user1.id, user2.id, user3.id)
If you don't pass database to constructor, you should define
`database` as a class member like that::
Expand All @@ -90,8 +90,6 @@ class MyManager(Manager):
objects = MyManager()
Can also handle multiple databases for single model.
Just create `Manager` instance per database.
"""
database = None

Expand Down Expand Up @@ -307,12 +305,13 @@ def _swap_database(self, query):
database, it's **WRONG AND DANGEROUS**, so assertion is raised.
"""
model = query.model_class
if model._meta.database == _AutoDatabase:
if model._meta.database == self.database:
return query
elif model._meta.database == _AutoDatabase:
# **Experimental** database swapping!
query = query.clone()
query.database = self.database
return query
elif model._meta.database == self.database:
return query
else:
assert False, ("Error, models's database and manager's "
"database are different: %s" % model)
Expand Down Expand Up @@ -354,16 +353,17 @@ def execute(query):
``Model.update()`` etc.
:return: result depends on query type, it's the same as for sync ``query.execute()``
"""
if isinstance(query, peewee.UpdateQuery):
if isinstance(query, peewee.SelectQuery):
coroutine = select
elif isinstance(query, peewee.UpdateQuery):
coroutine = update
elif isinstance(query, peewee.InsertQuery):
coroutine = insert
elif isinstance(query, peewee.DeleteQuery):
coroutine = delete
elif isinstance(query, peewee.RawQuery):
coroutine = raw_query
else:
coroutine = select
coroutine = raw_query

return (yield from coroutine(query))


Expand Down Expand Up @@ -484,9 +484,6 @@ def update_object(obj, only=None):
@asyncio.coroutine
def select(query):
"""Perform SELECT query asynchronously.
NOTE! It relies on internal peewee logic for generating
results from queries and well, a bit hacky.
"""
assert isinstance(query, peewee.SelectQuery),\
("Error, trying to run select coroutine"
Expand Down Expand Up @@ -601,9 +598,17 @@ def raw_query(query):
("Error, trying to run delete coroutine"
"with wrong query class %s" % str(query))

cursor = yield from _execute_query_async(query)
message = cursor.statusmessage
return message
result = AsyncRawQueryWrapper(query)
cursor = yield from result.execute()

try:
while True:
yield from result.fetchone()
except GeneratorExit:
pass

cursor.release()
return result


@asyncio.coroutine
Expand Down Expand Up @@ -694,17 +699,17 @@ def _get_result_wrapper(self, query):
"""Get result wrapper class.
"""
if query._tuples:
QR = query.database.get_result_wrapper(RESULTS_TUPLES)
QRW = query.database.get_result_wrapper(RESULTS_TUPLES)
elif query._dicts:
QR = query.database.get_result_wrapper(RESULTS_DICTS)
QRW = query.database.get_result_wrapper(RESULTS_DICTS)
elif query._naive or not query._joins or query.verify_naive():
QR = query.database.get_result_wrapper(RESULTS_NAIVE)
QRW = query.database.get_result_wrapper(RESULTS_NAIVE)
elif query._aggregate_rows:
QR = query.database.get_result_wrapper(RESULTS_AGGREGATE_MODELS)
QRW = query.database.get_result_wrapper(RESULTS_AGGREGATE_MODELS)
else:
QR = query.database.get_result_wrapper(RESULTS_MODELS)
QRW = query.database.get_result_wrapper(RESULTS_MODELS)

return QR(query.model_class, None, query.get_query_meta())
return QRW(query.model_class, None, query.get_query_meta())

@asyncio.coroutine
def execute(self):
Expand All @@ -727,6 +732,21 @@ def fetchone(self):
self._result.append(obj)


class AsyncRawQueryWrapper(AsyncQueryWrapper):
@classmethod
def _get_result_wrapper(self, query):
"""Get raw query result wrapper class.
"""
if query._tuples:
QRW = query.database.get_result_wrapper(RESULTS_TUPLES)
elif query._dicts:
QRW = query.database.get_result_wrapper(RESULTS_DICTS)
else:
QRW = query.database.get_result_wrapper(RESULTS_NAIVE)

return QRW(query.model_class, None, None)


##############
# PostgreSQL #
##############
Expand Down
23 changes: 23 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,29 @@ def test(objects):

self.run_with_managers(test)

def test_raw_query(self):
@asyncio.coroutine
def test(objects):
text = "Test %s" % uuid.uuid4()
yield from objects.create(TestModel, text=text)

result1 = yield from objects.execute(TestModel.raw(
'select id, text from testmodel'))
self.assertEqual(len(list(result1)), 1)
self.assertTrue(isinstance(result1[0], TestModel))

result2 = yield from objects.execute(TestModel.raw(
'select id, text from testmodel').tuples())
self.assertEqual(len(list(result2)), 1)
self.assertTrue(isinstance(result2[0], tuple))

result3 = yield from objects.execute(TestModel.raw(
'select id, text from testmodel').dicts())
self.assertEqual(len(list(result3)), 1)
self.assertTrue(isinstance(result3[0], dict))

self.run_with_managers(test)

def test_select_many_objects(self):
@asyncio.coroutine
def test(objects):
Expand Down

0 comments on commit 37a49a4

Please sign in to comment.