Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement BaseQuerySet.batch_size #1426

Merged
merged 3 commits into from
Dec 6, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(self, document, collection):
self._limit = None
self._skip = None
self._hint = -1 # Using -1 as None is a valid value for hint
self._batch_size = None
self.only_fields = []
self._max_time_ms = None

Expand Down Expand Up @@ -781,6 +782,19 @@ def hint(self, index=None):
queryset._hint = index
return queryset

def batch_size(self, size):
    """Limit the number of documents returned in a single batch (each
    batch requires a round trip to the server).

    The value is only recorded on a clone of this queryset; it is handed
    to the underlying cursor later, when the cursor is built. No
    validation is performed here.

    See http://api.mongodb.com/python/current/api/pymongo/cursor.html#pymongo.cursor.Cursor.batch_size
    for details.

    :param size: desired size of each batch.
    :returns: a clone of this queryset with the batch size recorded.
    """
    # Work on the object returned by clone() so that calls can be chained
    # and the receiver itself is never assigned to.
    batched = self.clone()
    batched._batch_size = size
    return batched

def distinct(self, field):
"""Return a list of distinct values for a given field.

Expand Down Expand Up @@ -1467,6 +1481,9 @@ def _cursor(self):
if self._hint != -1:
self._cursor_obj.hint(self._hint)

if self._batch_size is not None:
self._cursor_obj.batch_size(self._batch_size)

return self._cursor_obj

def __deepcopy__(self, memo):
Expand Down
28 changes: 28 additions & 0 deletions tests/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,34 @@ class B(Document):
query = query.filter(boolfield=True)
self.assertEqual(query.count(), 1)

def test_batch_size(self):
    """Check batch_size() when iterating, when chained, and when invalid."""
    class A(Document):
        s = StringField()

    A.drop_collection()

    # Populate the collection with 100 documents.
    for n in range(100):
        A.objects.create(s=str(n))

    # Iterating with a batch size of 10 must still yield every document.
    seen = sum(1 for _ in A.objects.batch_size(10))
    self.assertEqual(seen, 100)

    # batch_size() can sit in the middle of a chain: skipping 91 of the
    # 100 documents leaves 9, fewer than the limit of 10.
    chained = A.objects.all()
    chained = chained.limit(10).batch_size(20).skip(91)
    seen = sum(1 for _ in chained)
    self.assertEqual(seen, 9)

    # A negative batch size is expected to raise ValueError once the
    # queryset is evaluated. The callable form of assertRaises is used
    # because the context-manager form is unavailable on Python 2.6.
    invalid = A.objects.batch_size(-1)
    self.assertRaises(ValueError, lambda: list(invalid))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lambda looks like a workaround here. How about:

with self.assertRaises(ValueError):
    list(qs)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gukoff that won't work for Python 2.6 (which we still support). Thanks though, it was a good idea.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What could be done is self.assertRaises(ValueError, list, qs), but personally I don't find that more readable/less hacky.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that won't work for Python 2.6 (which we still support)

Thanks, I didn't know about this.


def test_update_write_concern(self):
"""Test that passing write_concern works"""
self.Person.drop_collection()
Expand Down